
Commit 7cd4e94

Copilot and gvegayon authored
Support Markov order > 1 in defm_motif_parser (#19)
* Initial plan

* Add support for order 2+ Markov models in defm_motif_parser

  - Updated regex pattern to match multiple bracketed groups
  - Added logic to handle both 2-group (backwards compatible) and m_order+1 group formulas
  - Created comprehensive test file 16b-defm-counts-with-formulas.cpp
  - All tests passing (72 assertions in 19 test cases)

  Co-authored-by: gvegayon <893619+gvegayon@users.noreply.github.com>

* Final verification - all tests passing

  Verified:
  - All 19 test cases pass with 72 assertions
  - Order 2 formulas like "{0y1} > {y1} > {0y1}" now work correctly
  - Backwards compatibility maintained for existing 2-group formulas
  - counter_generic() integration verified

  Co-authored-by: gvegayon <893619+gvegayon@users.noreply.github.com>

* Restore barry.hpp file

* Add comprehensive error testing for defm_motif_parser

  Added 15 error test cases to cover all throw statements:
  - Transition with m_order=0
  - Column out of range (all modes)
  - LHS without time when m_order>1
  - Row out of range (all modes)
  - Duplicate terms (all modes)
  - RHS with non-m_order row
  - Explicit row index mismatch
  - Wrong number of groups
  - Intercept with past event
  - Invalid formula syntax

  All tests passing (87 assertions in 19 test cases)

  Co-authored-by: gvegayon <893619+gvegayon@users.noreply.github.com>

* Address PR review comments

  - Updated function documentation to describe both 2-group and m_order+1 group modes
  - Clarified comment for 2-group mode behavior
  - Removed unused arrow_start and arrow_end variables
  - Improved comment clarity for RHS bracket check
  - Added safety check for group_idx with explanation about covariate handling
  - All tests still passing (87 assertions)

  Co-authored-by: gvegayon <893619+gvegayon@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: gvegayon <893619+gvegayon@users.noreply.github.com>
1 parent 8bd93c5 commit 7cd4e94

3 files changed: +545 -54 lines changed


include/barry/models/defm/formula.hpp

Lines changed: 178 additions & 54 deletions
@@ -31,9 +31,24 @@
  *
  * ## Transition effects
  *
- * Transition effects can be specified using two sets of curly brackets and
- * an greater-than symbol, i.e., `{...} > {...}`. The first set of brackets,
- * which we call LHS, can only hold `row id` that are less than `m_order`.
+ * Transition effects can be specified using curly brackets separated by
+ * greater-than symbols ('>').
+ *
+ * **Two-group mode (backwards compatible):** `{...} > {...}`
+ * - First group (LHS): Variables at times 0 to m_order-1. When m_order > 1,
+ * row indices must be explicitly specified.
+ * - Second group (RHS): Variables at time m_order. Row indices can be omitted
+ * and will default to m_order.
+ *
+ * **Multi-group mode (explicit):** `{...} > {...} > ... > {...}` (m_order+1 groups)
+ * - Each group corresponds to a specific time point (0, 1, ..., m_order).
+ * - Row indices can be omitted and will be inferred from group position.
+ * - If specified, row indices must match the group position.
+ *
+ * Examples:
+ * - Order 1: `{y0_0} > {y0_1}` or `{y0_0} > {y0}` (both valid)
+ * - Order 2: `{y0_0} > {y0}` (2-group, implicit final time)
+ * - Order 2: `{y0_0} > {y0_1} > {y0_2}` (3-group, all explicit)
  *
  *
  * @param formula A string specifying the motif formula (see details).
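The documented choice between the two modes depends only on how many bracketed groups the formula contains. The sketch below illustrates that rule in isolation. `TransitionMode`, `count_groups`, and `transition_mode` are hypothetical names introduced here for illustration and are not part of barry; only the bracket regex and the group-count rule come from this commit.

```cpp
#include <cstddef>
#include <iterator>
#include <regex>
#include <stdexcept>
#include <string>

// Hypothetical enum (not in barry): the two formula shapes the docs describe.
enum class TransitionMode { TwoGroup, MultiGroup };

// Counts the "{...}" groups in a formula with the same bracket regex
// that the commit uses to locate groups.
inline std::size_t count_groups(const std::string & formula)
{
    static const std::regex bracket_pattern("\\{[^}]+\\}");
    auto first = std::sregex_iterator(
        formula.begin(), formula.end(), bracket_pattern);
    return static_cast<std::size_t>(
        std::distance(first, std::sregex_iterator()));
}

// Mirrors the branch order in the commit: exactly two groups selects the
// backwards-compatible mode first; otherwise m_order + 1 groups are required.
inline TransitionMode transition_mode(
    const std::string & formula, std::size_t m_order)
{
    const std::size_t num_groups = count_groups(formula);
    if (num_groups == 2u)
        return TransitionMode::TwoGroup;
    if (num_groups == m_order + 1u)
        return TransitionMode::MultiGroup;
    throw std::logic_error("Unexpected number of bracketed groups.");
}
```

Under these assumptions, `transition_mode("{y0_0} > {y0}", 2)` selects the two-group mode and `transition_mode("{y0_0} > {y0_1} > {y0_2}", 2)` selects the multi-group mode. For order 1 both rules coincide, and the two-group branch is taken.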
@@ -63,9 +78,10 @@ inline void defm_motif_parser(
         std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\s*\\}") +
         std::string("(\\s*x\\s*[^\\s]+([(].+[)])?\\s*)?")
     );
+    // Updated pattern to match one or more bracketed groups separated by '>'
     std::regex pattern_transition(
-        std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\}\\s*(>)\\s*") +
         std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\s*\\}") +
+        std::string("(\\s*>\\s*\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\s*\\})+") +
         std::string("(\\s*x\\s*[^\\s]+([(].+[)])?\\s*)?")
     );
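To see what the updated `pattern_transition` accepts, the sketch below rebuilds the same regular expression from its three pieces and tests whole strings against it. The helper name `looks_like_transition` is hypothetical, and the use of `std::regex_match` is an assumption: this hunk does not show how the parser applies the pattern.

```cpp
#include <regex>
#include <string>

// Hypothetical helper (not in barry): true when the whole string has the
// shape "{...} > {...} [> {...} ...] [x covariate]".
inline bool looks_like_transition(const std::string & formula)
{
    // One bracketed group of terms such as "y0", "0y1", or "y2_1".
    static const std::string group =
        "\\{\\s*[01]?y[0-9]+(_[0-9]+)?"
        "(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\s*\\}";

    // First group, then one or more "> {group}", then an optional
    // covariate part, as in the pattern added by this commit.
    static const std::regex pattern_transition(
        group +
        "(\\s*>\\s*" + group + ")+" +
        "(\\s*x\\s*[^\\s]+([(].+[)])?\\s*)?"
    );

    return std::regex_match(formula, pattern_transition);
}
```

With this sketch, two-group and three-group formulas both match, while a lone group such as `{y0}` does not, because at least one `>` followed by a group is required.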

@@ -103,79 +119,187 @@ inline void defm_motif_parser(

     }

-    // Will indicate where the arrow is located at
-    size_t arrow_position = match.position(4u);
+    // Find all bracketed groups to determine which time point each variable belongs to
+    std::regex bracket_pattern("\\{[^}]+\\}");
+    std::vector<std::pair<size_t, size_t>> bracket_ranges; // start, end positions
+
+    auto brackets_begin = std::sregex_iterator(formula.begin(), formula.end(), bracket_pattern);
+    for (auto i = brackets_begin; i != empty; ++i)
+        bracket_ranges.push_back({i->position(), i->position() + i->length()});

-    // This pattern will match
-    std::regex pattern("(0?)y([0-9]+)(_([0-9]+))?");
-
-    auto iter = std::sregex_iterator(formula.begin(), formula.end(), pattern);
+    size_t num_groups = bracket_ranges.size();

-    for (auto i = iter; i != empty; ++i)
+    // For backwards compatibility, allow 2 groups (original behavior) or m_order+1 groups (new behavior)
+    if (num_groups == 2)
     {
+        // Two-group mode (backwards compatible):
+        // - First group: variables at time 0 to m_order-1 (must have explicit row when m_order > 1)
+        // - Second group: variables at time m_order (row can be implicit or explicit)

-        // Baseline position
-        size_t current_location = i->position(0u);
-
-        // First value true/false
-        bool is_positive = true;
-        if (i->operator[](1u).str() == "0")
-            is_positive = false;
+        // This pattern will match
+        std::regex pattern("(0?)y([0-9]+)(_([0-9]+))?");

-        // Variable position
-        size_t y_col = std::stoul(i->operator[](2u).str());
-        if (y_col >= y_ncol)
-            throw std::logic_error("The proposed column is out of range.");
+        auto iter = std::sregex_iterator(formula.begin(), formula.end(), pattern);

-        // Time location
-        size_t y_row;
-        std::string tmp_str = i->operator[](4u).str();
-        if (m_order > 1)
+        for (auto i = iter; i != empty; ++i)
         {
-            // If missing, we replace with the location
-            if (tmp_str == "")
+
+            // Baseline position
+            size_t current_location = i->position(0u);
+
+            // First value true/false
+            bool is_positive = true;
+            if (i->operator[](1u).str() == "0")
+                is_positive = false;
+
+            // Variable position
+            size_t y_col = std::stoul(i->operator[](2u).str());
+            if (y_col >= y_ncol)
+                throw std::logic_error("The proposed column is out of range.");
+
+            // Time location
+            size_t y_row;
+            std::string tmp_str = i->operator[](4u).str();
+            if (m_order > 1)
             {
+                // If missing, we replace with the location
+                if (tmp_str == "")
+                {

-                if (current_location > arrow_position)
-                    y_row = m_order;
+                    if (current_location >= bracket_ranges[1].first)
+                        y_row = m_order;
+                    else
+                        throw std::logic_error("LHS of transition must specify time when m_order > 1");
+
+                } else
+                    y_row = std::stoul(tmp_str);
+
+                if (y_row > m_order)
+                    throw std::logic_error("The proposed row is out of range.");
+
+
+            } else {
+
+                // If missing, we replace with the location
+                if (tmp_str != "")
+                    y_row = std::stoul(tmp_str);
                 else
-                    throw std::logic_error("LHS of transition must specify time when m_order > 1");
+                    y_row = (current_location < bracket_ranges[1].first ? 0u: 1u);

-            } else
-                y_row = std::stoul(tmp_str);
+            }

-            if (y_row > m_order)
-                throw std::logic_error("The proposed row is out of range.");
+            if (selected[y_col * (m_order + 1) + y_row])
+                throw std::logic_error(
+                    "The term " + i->str() + " shows more than once in the formula.");
+
+            // Only variables at time m_order can be in the RHS (second bracketed group)
+            if ((current_location >= bracket_ranges[1].first) && (y_row != m_order))
+                throw std::logic_error(
+                    "Only the row " + std::to_string(m_order) +
+                    " can be specified at the RHS of the motif."
+                );
+
+            selected[y_col * (m_order + 1) + y_row] = true;
+
+            locations.push_back(y_col * (m_order + 1) + y_row);
+            signs.push_back(is_positive);
+
+
+        }
+    }
+    else if (num_groups == m_order + 1)
+    {
+        // New behavior: each group corresponds to a time point (0 to m_order)
+
+        // This pattern will match
+        std::regex pattern("(0?)y([0-9]+)(_([0-9]+))?");

+        auto iter = std::sregex_iterator(formula.begin(), formula.end(), pattern);

-        } else {
+        for (auto i = iter; i != empty; ++i)
+        {

-            // If missing, we replace with the location
+            // Baseline position
+            size_t current_location = i->position(0u);
+
+            // First value true/false
+            bool is_positive = true;
+            if (i->operator[](1u).str() == "0")
+                is_positive = false;
+
+            // Variable position
+            size_t y_col = std::stoul(i->operator[](2u).str());
+            if (y_col >= y_ncol)
+                throw std::logic_error("The proposed column is out of range.");
+
+            // Determine which bracketed group this variable belongs to
+            // Note: The regex pattern ensures all variables are within bracketed groups,
+            // as the pattern only matches content inside brackets. Variables in covariate
+            // expressions (after 'x') are handled separately by pattern_conditional above.
+            size_t group_idx = 0;
+            bool found_group = false;
+            for (size_t g = 0; g < bracket_ranges.size(); ++g)
+            {
+                if (current_location >= bracket_ranges[g].first &&
+                    current_location < bracket_ranges[g].second)
+                {
+                    group_idx = g;
+                    found_group = true;
+                    break;
+                }
+            }
+
+            // Safety check: ensure the variable was found in a bracketed group
+            // This should never happen given the regex, but verify for safety
+            if (!found_group)
+                throw std::logic_error(
+                    "Internal error: variable " + i->str() +
+                    " not found within any bracketed group.");
+
+            // Time location
+            size_t y_row;
+            std::string tmp_str = i->operator[](4u).str();
+
+            // If row is explicitly specified, use it; otherwise infer from group position
             if (tmp_str != "")
+            {
                 y_row = std::stoul(tmp_str);
+
+                // Validate that explicit row matches the expected group position
+                if (y_row != group_idx)
+                    throw std::logic_error(
+                        "Explicit row index " + std::to_string(y_row) +
+                        " does not match the position in the formula (group " +
+                        std::to_string(group_idx) + ").");
+            }
             else
-                y_row = (current_location < arrow_position ? 0u: 1u);
-
-        }
+            {
+                // Infer row from bracketed group position
+                y_row = group_idx;
+            }

-        if (selected[y_col * (m_order + 1) + y_row])
-            throw std::logic_error(
-                "The term " + i->str() + " shows more than once in the formula.");
+            if (y_row > m_order)
+                throw std::logic_error("The proposed row is out of range.");

-        // Only the end of the chain can be located at position after the
-        // arrow
-        if ((current_location > arrow_position) && (y_row != m_order))
-            throw std::logic_error(
-                "Only the row " + std::to_string(m_order) +
-                " can be specified at the RHS of the motif."
-            );
+            if (selected[y_col * (m_order + 1) + y_row])
+                throw std::logic_error(
+                    "The term " + i->str() + " shows more than once in the formula.");

-        selected[y_col * (m_order + 1) + y_row] = true;
+            selected[y_col * (m_order + 1) + y_row] = true;

-        locations.push_back(y_col * (m_order + 1) + y_row);
-        signs.push_back(is_positive);
-
+            locations.push_back(y_col * (m_order + 1) + y_row);
+            signs.push_back(is_positive);
+

+        }
+    }
+    else
+    {
+        throw std::logic_error(
+            "For a Markov model of order " + std::to_string(m_order) +
+            ", transition formulas must have either 2 bracketed groups " +
+            "(for backwards compatibility) or exactly " + std::to_string(m_order + 1) +
+            " bracketed groups (found " + std::to_string(num_groups) + ").");
     }

     return;
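The multi-group branch above can be hard to follow inside a diff, so here is a minimal standalone sketch of the same idea: locate each term, find the bracketed group it sits in, take the row from the group position (or check an explicit row against it), and store the flat index `y_col * (m_order + 1) + y_row`. `ParsedMotif` and `parse_multigroup` are hypothetical names used only for this illustration. The real `defm_motif_parser` has a different signature, writes into caller-provided vectors, and also handles the two-group mode and covariates.

```cpp
#include <cstddef>
#include <regex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical result type (not in barry).
struct ParsedMotif {
    std::vector<std::size_t> locations; // flat index: y_col * (m_order + 1) + y_row
    std::vector<bool> signs;            // false when the term has a leading "0"
};

// Hypothetical, simplified version of the multi-group branch only.
inline ParsedMotif parse_multigroup(
    const std::string & formula, std::size_t m_order, std::size_t y_ncol)
{
    const std::sregex_iterator end;

    // Record the [start, end) character range of every "{...}" group.
    const std::regex bracket_pattern("\\{[^}]+\\}");
    std::vector<std::pair<std::size_t, std::size_t>> ranges;
    for (auto b = std::sregex_iterator(formula.begin(), formula.end(), bracket_pattern);
         b != end; ++b)
    {
        const auto start = static_cast<std::size_t>(b->position());
        ranges.push_back({start, start + static_cast<std::size_t>(b->length())});
    }

    if (ranges.size() != m_order + 1u)
        throw std::logic_error("Expected m_order + 1 bracketed groups.");

    const std::regex term_pattern("(0?)y([0-9]+)(_([0-9]+))?");
    std::vector<bool> selected(y_ncol * (m_order + 1u), false);
    ParsedMotif out;

    for (auto t = std::sregex_iterator(formula.begin(), formula.end(), term_pattern);
         t != end; ++t)
    {
        // Which group contains this term? The group index is the time point.
        const auto pos = static_cast<std::size_t>(t->position(0));
        std::size_t group_idx = ranges.size();
        for (std::size_t g = 0u; g < ranges.size(); ++g)
            if (pos >= ranges[g].first && pos < ranges[g].second)
            {
                group_idx = g;
                break;
            }
        if (group_idx == ranges.size())
            throw std::logic_error("Term found outside any bracketed group.");

        const std::size_t y_col = std::stoul((*t)[2].str());
        if (y_col >= y_ncol)
            throw std::logic_error("The proposed column is out of range.");

        // An explicit row must agree with the group position; otherwise infer it.
        const std::string row_str = (*t)[4].str();
        const std::size_t y_row = row_str.empty() ? group_idx : std::stoul(row_str);
        if (y_row != group_idx)
            throw std::logic_error("Explicit row does not match group position.");

        const std::size_t idx = y_col * (m_order + 1u) + y_row;
        if (selected[idx])
            throw std::logic_error("Term appears more than once.");
        selected[idx] = true;

        out.locations.push_back(idx);
        out.signs.push_back((*t)[1].str() != "0");
    }

    return out;
}
```

For the order 2 formula from the commit message, `{0y1} > {y1} > {0y1}`, with two columns, each term refers to column 1 at rows 0, 1, and 2, so the sketch yields flat locations 3, 4, and 5 with signs false, true, false.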
