11package org.ivdnt.galahad.evaluation.metrics
22
33import com.fasterxml.jackson.annotation.JsonValue
4+ import org.ivdnt.galahad.annotations.Annotation
45import org.ivdnt.galahad.annotations.Term
56import org.ivdnt.galahad.evaluation.EvaluationEntry
67import org.ivdnt.galahad.evaluation.comparison.LayerComparison
78import org.ivdnt.galahad.evaluation.comparison.TermComparison
9+ import org.ivdnt.galahad.evaluation.metrics.ClsClasses.Companion.toMetrics
810
911class ClsMetrics (
1012 val precision : Float = 0f ,
@@ -28,76 +30,74 @@ class ClsMetrics(
2830}
2931
3032class NewMetric (
33+ val settings : MetricsSettings ,
3134 val grouped : MutableMap <String , ClsClasses > = mutableMapOf(),
32- ) {
33- val classes: ClsClasses
34- get() = grouped.filter { it.key != TermComparison .MISSING_MATCH }.values.reduce { a, b -> a + b }
35-
36- val micro: ClsMetrics get() = classes.metrics
37-
38- val macro: ClsMetrics
39- get() {
40- val validClasses = grouped.filter { it.key != TermComparison .MISSING_MATCH }
41- return validClasses.values.map { it.metrics }
42- .reduce { a, b -> a + b } / validClasses.size
43- }
44- }
35+ val classes : ClsClasses = grouped.filter { it.key != TermComparison .MISSING_MATCH }.values.reduce { a, b -> a + b },
36+ val accuracy : Float = classes.truePositive.count / classes.count.toFloat(),
37+ val macro : ClsMetrics = grouped.values.map { it.metrics }.reduce { a, b -> a + b } / grouped.size,
38+ )
4539
4640class ClsClasses (
4741 var truePositive : EvaluationEntry = EvaluationEntry (),
4842 var falsePositive : EvaluationEntry = EvaluationEntry (),
4943 var falseNegative : EvaluationEntry = EvaluationEntry (),
5044 var noMatch : EvaluationEntry = EvaluationEntry (),
45+ var count : Int = truePositive.count + falseNegative.count + noMatch.count,
46+ var metrics : ClsMetrics = toMetrics(truePositive, falsePositive, falseNegative, noMatch)
5147) {
5248 fun add (other : ClsClasses , truncate : Boolean = true): ClsClasses {
5349 truePositive = EvaluationEntry .add(truePositive, other.truePositive, truncate)
5450 falsePositive = EvaluationEntry .add(falsePositive, other.falsePositive, truncate)
5551 falseNegative = EvaluationEntry .add(falseNegative, other.falseNegative, truncate)
5652 noMatch = EvaluationEntry .add(noMatch, other.noMatch, truncate)
53+ count = truePositive.count + falseNegative.count + noMatch.count
54+ metrics = toMetrics(truePositive, falsePositive, falseNegative, noMatch)
5755 return this
5856 }
5957
6058 operator fun plus (other : ClsClasses ): ClsClasses = ClsClasses (
61- truePositive = EvaluationEntry .from(other.truePositive,truePositive),
62- falsePositive = EvaluationEntry .from(other.falsePositive,falsePositive),
63- falseNegative = EvaluationEntry .from(other.falseNegative,falseNegative),
64- noMatch = EvaluationEntry .from(other.noMatch,noMatch),
59+ truePositive = EvaluationEntry .from(other.truePositive, truePositive),
60+ falsePositive = EvaluationEntry .from(other.falsePositive, falsePositive),
61+ falseNegative = EvaluationEntry .from(other.falseNegative, falseNegative),
62+ noMatch = EvaluationEntry .from(other.noMatch, noMatch),
63+ count = truePositive.count + falseNegative.count + noMatch.count,
64+ metrics = toMetrics(truePositive, falsePositive, falseNegative, noMatch)
6565 )
6666
67- val metrics: ClsMetrics
68- get() {
69- val precision = notNaN(truePositive.count / (truePositive.count + falsePositive.count).toFloat())
70- val recall = notNaN(truePositive.count / (truePositive.count + falsePositive.count + noMatch.count).toFloat())
67+ companion object {
68+ fun notNaN (value : Float ): Float = if (value.isNaN()) 0.0f else value
69+ fun toMetrics (
70+ truePositive : EvaluationEntry ,
71+ falsePositive : EvaluationEntry ,
72+ falseNegative : EvaluationEntry ,
73+ noMatch : EvaluationEntry
74+ ): ClsMetrics {
75+ val tp = truePositive.count.toFloat()
76+ val fp = falsePositive.count.toFloat()
77+ val fn = falseNegative.count.toFloat()
78+ val mm = noMatch.count.toFloat()
79+ val precision = notNaN(tp / (tp + fp))
80+ val recall = notNaN(tp / (tp + fn + mm))
7181 val f1 = notNaN(2.0f * (precision * recall) / (precision + recall))
7282 return ClsMetrics (precision, recall, f1)
7383 }
74-
75- val count: Int
76- get() = truePositive.count + falseNegative.count + noMatch.count
77-
78- companion object {
79- fun notNaN (value : Float ): Float = if (value.isNaN()) 0.0f else value
8084 }
8185}
8286
83-
8487class DocumentMetric (
85- @JsonValue
86- val classesByGroup : MutableMap <String , NewMetric >
88+ @JsonValue val classesByGroup : MutableMap <String , NewMetric >
8789) {
8890 companion object {
89- fun create (layerComparison : LayerComparison ): DocumentMetric = DocumentMetric (
90- buildMap<String , NewMetric > {
91+ fun create (layerComparison : LayerComparison , annotations : Set < Annotation > ): DocumentMetric = DocumentMetric (
92+ buildMap<String , MutableMap < String , ClsClasses > > {
9193 layerComparison.matches.forEach { tc ->
92- METRIC_TYPES .forEach { metricType ->
94+ METRIC_TYPES .filter { annotations.containsAll(it.requiredAnnotations) }. forEach { metricType ->
9395 if (! metricType.filterBy(tc)) return @forEach
9496 val mapsToAdd = mutableListOf<MutableMap <String , ClsClasses >>()
9597 if (tc.hyp == Term .EMPTY ) {
9698 // handle missing match
9799 val cls = ClsClasses (noMatch = EvaluationEntry (1 , mutableListOf (tc)))
98- val group = TermComparison .MISSING_MATCH
99- val classesMap = mutableMapOf (group to cls)
100- mapsToAdd.add(classesMap)
100+ mapsToAdd.add(mutableMapOf (metricType.groupBy(tc.ref) to cls))
101101 } else {
102102 // handle true positive & false negative
103103 val (trueEntry, falseEntry) = truesFalses(tc, metricType::termsEqual)
@@ -115,16 +115,18 @@ class DocumentMetric(
115115 }
116116 for (map in mapsToAdd) {
117117
118- merge(metricType.id, NewMetric (grouped= map)) { oldMetricMap, newMetricMap -> oldMetricMap.apply {
119- this .grouped.merge(newMetricMap.grouped.keys.first(), newMetricMap.grouped.values.first()) { oldCls, newCls ->
120- oldCls.add(newCls)
121- }
122- }
118+ merge(
119+ metricType.id,
120+ map
121+ ) {
122+ a, b ->
123+ a.apply { this .merge(b.keys.first(), b.values.first()) { x, y -> x.add(y) } }
123124 }
125+
124126 }
125127 }
126128 }
127- }.toMutableMap()
129+ }.mapValues { NewMetric ( METRIC_TYPES .first{mt -> mt.id == it.key}, it.value)}. toMutableMap()
128130 )
129131
130132 private fun truesFalses (
0 commit comments