Skip to content

Commit e4b6d14

Browse files
committed
Adding a print method for counters
Also included new test for defm models, checking that forcing new support calculation when initializing works. Refactored how the defm hasher works; it is now a factory function.
1 parent 8e4c308 commit e4b6d14

File tree

8 files changed

+453
-52
lines changed

8 files changed

+453
-52
lines changed

barry.hpp

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5091,6 +5091,11 @@ class Counter {
50915091
void set_hasher(Hasher_fun_type<Array_Type,Data_Type> fun);
50925092
Hasher_fun_type<Array_Type,Data_Type> get_hasher();
50935093
///@}
5094+
5095+
/**
5096+
* @brief Print a summary of the counter.
5097+
*/
5098+
void print() const;
50945099

50955100
};
50965101

@@ -5189,9 +5194,25 @@ class Counters {
51895194
bool add_dims = true
51905195
);
51915196

5197+
/**
5198+
* @brief Set the hasher function in addition to
5199+
* the individual hasher functions of each counter.
5200+
* @param fun_ A hasher function that will be appended to the
5201+
* hash generated by the individual counters.
5202+
*/
51925203
void add_hash(
51935204
Hasher_fun_type<Array_Type,Data_Type> fun_
51945205
);
5206+
5207+
/**
5208+
* @brief Print a summary of the counters in the set.
5209+
* @param max_length_name Maximum length of the name to be printed.
5210+
* @param max_length_desc Maximum length of the description to be printed.
5211+
*/
5212+
void print(
5213+
size_t max_length_name = 40,
5214+
size_t max_length_desc = 40
5215+
) const;
51955216

51965217
};
51975218

@@ -5322,7 +5343,7 @@ COUNTER_TEMPLATE(std::string, get_name)() const {
53225343
}
53235344

53245345
COUNTER_TEMPLATE(std::string, get_description)() const {
5325-
return this->name;
5346+
return this->desc;
53265347
}
53275348

53285349
COUNTER_TEMPLATE(void, set_name)(std::string new_name) {
@@ -5337,6 +5358,16 @@ COUNTER_TEMPLATE(void, set_hasher)(Hasher_fun_type<Array_Type,Data_Type> fun) {
53375358
hasher_fun = fun;
53385359
}
53395360

5361+
COUNTER_TEMPLATE(void, print)() const {
5362+
5363+
printf_barry("Counter:\n");
5364+
printf_barry(" Name : %s\n", this->get_name().c_str());
5365+
printf_barry(" Description: %s\n", this->get_description().c_str());
5366+
5367+
return;
5368+
5369+
}
5370+
53405371
#define TMP_HASHER_CALL Hasher_fun_type<Array_Type,Data_Type>
53415372
COUNTER_TEMPLATE(TMP_HASHER_CALL, get_hasher)() {
53425373
return hasher_fun;
@@ -5499,6 +5530,45 @@ COUNTERS_TEMPLATE(void, add_hash)(
54995530

55005531
}
55015532

5533+
COUNTERS_TEMPLATE(void, print)(
5534+
size_t max_length_name,
5535+
size_t max_length_desc
5536+
) const {
5537+
5538+
// Iterating through the counters to see the maximum name length
5539+
size_t max_name_length = 0;
5540+
for (const auto & c : data)
5541+
{
5542+
max_name_length = std::max(max_name_length, c.get_name().size());
5543+
}
5544+
5545+
max_name_length = std::min(max_name_length, max_length_name);
5546+
5547+
// Figuring out the format string so it looks nice
5548+
char fmt[100];
5549+
snprintf(fmt, sizeof(fmt), " - %%-%zus : %%s\n", max_name_length);
5550+
5551+
printf_barry("Counters (%zu):\n", this->size());
5552+
for (size_t i = 0u; i < this->size(); ++i)
5553+
{
5554+
// Figuring out the string to print (if needs truncation)
5555+
auto name_to_print = data.at(i).get_name();
5556+
if (name_to_print.size() > max_name_length)
5557+
name_to_print = name_to_print.substr(0, max_name_length - 3) + "...";
5558+
auto desc_to_print = data.at(i).get_description();
5559+
if (desc_to_print.size() > max_length_desc)
5560+
desc_to_print = desc_to_print.substr(0, max_length_desc - 3) + "...";
5561+
5562+
auto c = data.at(i);
5563+
printf_barry(
5564+
fmt,
5565+
// i,
5566+
name_to_print.c_str(),
5567+
desc_to_print.c_str()
5568+
);
5569+
}
5570+
}
5571+
55025572
#undef COUNTER_TYPE
55035573
#undef COUNTER_TEMPLATE_ARGS
55045574
#undef COUNTER_TEMPLATE

defm.hpp

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -527,34 +527,71 @@ inline void defm_motif_parser(
527527
*/
528528
///@{
529529

530+
///@}
530531

532+
/**
533+
* @brief Factory to create a hasher function for DEFMArray
534+
* @param covar_index If >= 0, then the hasher will include the
535+
* covariate at that index as part of the hash.
536+
* @return A hasher function for DEFMArray
537+
*/
538+
inline barry::Hasher_fun_type<DEFMArray, DEFMCounterData>
539+
defm_hasher_factory(int covar_index = -1) {
540+
541+
// With no covariate index, we skip adding that
542+
// layer to the hasher
543+
if (covar_index >= 0)
544+
{
531545

546+
return [covar_index](
547+
const DEFMArray & array,
548+
DEFMCounterData * d
549+
) -> std::vector< double > {
532550

551+
std::vector< double > res;
533552

534-
///@}
553+
// Adding the column feature
554+
for (size_t i = 0u; i < array.nrow(); ++i)
555+
{
556+
res.push_back(array.D()(i, covar_index));
557+
}
535558

536-
/**
537-
* @brief Data for the counters
538-
*
539-
* @details This class is used to store the data for the counters. It is
540-
* used by the `Counters` class.
541-
*
542-
*/
543-
#define MAKE_DEFM_HASHER(hasher,a,cov) \
544-
barry::Hasher_fun_type<DEFMArray, DEFMCounterData> \
545-
hasher = [cov](const DEFMArray & array, DEFMCounterData * d) -> \
546-
std::vector< double > { \
547-
std::vector< double > res; \
548-
/* Adding the column feature */ \
549-
for (size_t i = 0u; i < array.nrow(); ++i) \
550-
res.push_back(array.D()(i, cov)); \
551-
/* Adding the fixed dims */ \
552-
for (size_t i = 0u; i < (array.nrow() - 1); ++i) \
553-
for (size_t j = 0u; j < array.ncol(); ++j) \
554-
res.push_back(array(i, j)); \
555-
return res;\
559+
// Adding the fixed dims
560+
for (size_t i = 0u; i < (array.nrow() - 1); ++i)
561+
{
562+
for (size_t j = 0u; j < array.ncol(); ++j)
563+
{
564+
res.push_back(array(i, j));
565+
}
566+
}
567+
568+
return res;
569+
};
570+
571+
} else {
572+
573+
return [](
574+
const DEFMArray & array,
575+
DEFMCounterData * d
576+
) -> std::vector< double > {
577+
578+
std::vector< double > res;
579+
580+
// Adding the fixed dims
581+
for (size_t i = 0u; i < (array.nrow() - 1); ++i)
582+
{
583+
for (size_t j = 0u; j < array.ncol(); ++j)
584+
{
585+
res.push_back(array(i, j));
586+
}
587+
}
588+
589+
return res;
556590
};
557591

592+
}
593+
594+
}
558595

559596
/**@name Macros for defining counters
560597
*/
@@ -614,7 +651,7 @@ inline void counter_ones(
614651
if (covar_index >= 0)
615652
{
616653

617-
MAKE_DEFM_HASHER(hasher, array, covar_index)
654+
auto hasher = defm_hasher_factory(covar_index);
618655

619656
DEFM_COUNTER_LAMBDA(counter_tmp)
620657
{
@@ -749,7 +786,7 @@ inline void counter_logit_intercept(
749786
return Array.D()(i, data.idx(1u));
750787
};
751788

752-
MAKE_DEFM_HASHER(hasher, array, covar_index)
789+
auto hasher = defm_hasher_factory(covar_index);
753790
bool hasher_added = false;
754791

755792
std::string yname;
@@ -1085,7 +1122,7 @@ inline void counter_transition(
10851122
if (covar_index >= 0)
10861123
{
10871124

1088-
MAKE_DEFM_HASHER(hasher, array, covar_index)
1125+
auto hasher = defm_hasher_factory(covar_index);
10891126

10901127
if (vname == "")
10911128
{
@@ -1225,7 +1262,7 @@ inline void counter_fixed_effect(
12251262
return 0.0;
12261263
};
12271264

1228-
MAKE_DEFM_HASHER(hasher, array, covar_index)
1265+
auto hasher = defm_hasher_factory(covar_index);
12291266

12301267
if (x_names != nullptr)
12311268
vname = x_names->operator[](covar_index);

include/barry/counters-bones.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ class Counter {
9696
void set_hasher(Hasher_fun_type<Array_Type,Data_Type> fun);
9797
Hasher_fun_type<Array_Type,Data_Type> get_hasher();
9898
///@}
99+
100+
/**
101+
* @brief Print a summary of the counter.
102+
*/
103+
void print() const;
99104

100105
};
101106

@@ -194,9 +199,25 @@ class Counters {
194199
bool add_dims = true
195200
);
196201

202+
/**
203+
* @brief Set the hasher function in addition to
204+
* the individual hasher functions of each counter.
205+
* @param fun_ A hasher function that will be appended to the
206+
* hash generated by the individual counters.
207+
*/
197208
void add_hash(
198209
Hasher_fun_type<Array_Type,Data_Type> fun_
199210
);
211+
212+
/**
213+
* @brief Print a summary of the counters in the set.
214+
* @param max_length_name Maximum length of the name to be printed.
215+
* @param max_length_desc Maximum length of the description to be printed.
216+
*/
217+
void print(
218+
size_t max_length_name = 40,
219+
size_t max_length_desc = 40
220+
) const;
200221

201222
};
202223

include/barry/counters-meat.hpp

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ COUNTER_TEMPLATE(std::string, get_name)() const {
105105
}
106106

107107
COUNTER_TEMPLATE(std::string, get_description)() const {
108-
return this->name;
108+
return this->desc;
109109
}
110110

111111
COUNTER_TEMPLATE(void, set_name)(std::string new_name) {
@@ -120,6 +120,16 @@ COUNTER_TEMPLATE(void, set_hasher)(Hasher_fun_type<Array_Type,Data_Type> fun) {
120120
hasher_fun = fun;
121121
}
122122

123+
COUNTER_TEMPLATE(void, print)() const {
124+
125+
printf_barry("Counter:\n");
126+
printf_barry(" Name : %s\n", this->get_name().c_str());
127+
printf_barry(" Description: %s\n", this->get_description().c_str());
128+
129+
return;
130+
131+
}
132+
123133
#define TMP_HASHER_CALL Hasher_fun_type<Array_Type,Data_Type>
124134
COUNTER_TEMPLATE(TMP_HASHER_CALL, get_hasher)() {
125135
return hasher_fun;
@@ -282,6 +292,45 @@ COUNTERS_TEMPLATE(void, add_hash)(
282292

283293
}
284294

295+
COUNTERS_TEMPLATE(void, print)(
296+
size_t max_length_name,
297+
size_t max_length_desc
298+
) const {
299+
300+
// Iterating through the counters to see the maximum name length
301+
size_t max_name_length = 0;
302+
for (const auto & c : data)
303+
{
304+
max_name_length = std::max(max_name_length, c.get_name().size());
305+
}
306+
307+
max_name_length = std::min(max_name_length, max_length_name);
308+
309+
// Figuring out the format string so it looks nice
310+
char fmt[100];
311+
snprintf(fmt, sizeof(fmt), " - %%-%zus : %%s\n", max_name_length);
312+
313+
printf_barry("Counters (%zu):\n", this->size());
314+
for (size_t i = 0u; i < this->size(); ++i)
315+
{
316+
// Figuring out the string to print (if needs truncation)
317+
auto name_to_print = data.at(i).get_name();
318+
if (name_to_print.size() > max_name_length)
319+
name_to_print = name_to_print.substr(0, max_name_length - 3) + "...";
320+
auto desc_to_print = data.at(i).get_description();
321+
if (desc_to_print.size() > max_length_desc)
322+
desc_to_print = desc_to_print.substr(0, max_length_desc - 3) + "...";
323+
324+
auto c = data.at(i);
325+
printf_barry(
326+
fmt,
327+
// i,
328+
name_to_print.c_str(),
329+
desc_to_print.c_str()
330+
);
331+
}
332+
}
333+
285334
#undef COUNTER_TYPE
286335
#undef COUNTER_TEMPLATE_ARGS
287336
#undef COUNTER_TEMPLATE

0 commit comments

Comments
 (0)