Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.nio.LongBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -141,6 +142,12 @@ public DocumentCategorizerDL(File model, File vocabulary, File config,

}

/**
* Returns a zero-filled score array with one entry per category if inference fails (the
* failure is logged), so the rest of the {@link DocumentCategorizer} API — e.g.
* {@link #scoreMap}, {@link #sortedScoreMap}, {@link #getBestCategory} — stays safe to call
* after an error instead of throwing on an empty array.
*/
@Override
public double[] categorize(String[] strings) {

Expand Down Expand Up @@ -198,10 +205,11 @@ public double[] categorize(String[] strings) {
return classificationScoringStrategy.score(scores);

} catch (Exception ex) {
logger.error("Unload to perform document classification inference", ex);
logger.error("Unable to perform document classification inference", ex);
}

return new double[] {};
// Sized to the category count so scoreMap()/sortedScoreMap() never index out of bounds.
return new double[categories.size()];

}

Expand Down Expand Up @@ -298,23 +306,13 @@ private List<Tokens> tokenize(final String text) {
// Split the input text into 200 word chunks with 50 overlapping between chunks.
final String[] whitespaceTokenized = text.split("\\s+");

for (int start = 0; start < whitespaceTokenized.length;
start = start + inferenceOptions.getDocumentSplitSize()) {

// 200 word length chunk
// Check the end do don't go past and get a StringIndexOutOfBoundsException
int end = start + inferenceOptions.getDocumentSplitSize();
if (end > whitespaceTokenized.length) {
end = whitespaceTokenized.length;
}

// The group is that subsection of string.
final String group = String.join(" ", Arrays.copyOfRange(whitespaceTokenized, start, end));
for (final int[] range : chunkRanges(whitespaceTokenized.length,
inferenceOptions.getDocumentSplitSize(), inferenceOptions.getSplitOverlapSize())) {

// We want to overlap each chunk by 50 words so scoot back 50 words for the next iteration.
start = start - inferenceOptions.getSplitOverlapSize();
// The group is that subsection of the input.
final String group =
String.join(" ", Arrays.copyOfRange(whitespaceTokenized, range[0], range[1]));

// Now we can tokenize the group and continue.
final String[] tokens = tokenizer.tokenize(group);

final long[] ids = tokenIds(tokens, vocab);
Expand All @@ -333,6 +331,32 @@ private List<Tokens> tokenize(final String text) {

}

/**
 * Computes the {@code [start, end)} word-index ranges the input is split into: chunks of
 * {@code splitSize} words overlapping by {@code overlapSize}.
 * <p>
 * The method is deliberately lenient towards misconfiguration so that chunking always
 * terminates and always covers the whole input:
 * <ul>
 *   <li>{@code overlapSize >= splitSize} can neither stall the loop nor produce negative
 *       indices, because the start index always advances by at least one word;</li>
 *   <li>{@code splitSize < 1} is treated as {@code 1}, because a non-positive chunk size would
 *       only yield empty or inverted ranges that never reach the end of the input;</li>
 *   <li>{@code overlapSize < 0} is treated as {@code 0}, because a negative overlap would move
 *       the next start past the previous end and silently skip words.</li>
 * </ul>
 *
 * @param length The number of whitespace-separated words.
 * @param splitSize The chunk size in words; values below {@code 1} are clamped to {@code 1}.
 * @param overlapSize The overlap between consecutive chunks in words; negative values are
 *                    clamped to {@code 0}.
 * @return The ordered list of {@code [start, end)} ranges; empty when {@code length <= 0}.
 *         Every range is non-empty and the last one ends at {@code length}.
 */
static List<int[]> chunkRanges(final int length, final int splitSize, final int overlapSize) {
  final int chunk = Math.max(splitSize, 1);
  final int overlap = Math.max(overlapSize, 0);

  final List<int[]> ranges = new ArrayList<>();
  int start = 0;
  while (start < length) {
    // Widen to long so that start + chunk cannot wrap around for very large split sizes.
    final int end = (int) Math.min((long) start + chunk, length);
    ranges.add(new int[] {start, end});
    if (end == length) {
      break;
    }
    // Overlap by the requested number of words, but always move forward by at least one.
    start = Math.max(end - overlap, start + 1);
  }
  return ranges;
}

/**
* Maps tokens to their vocabulary ids.
*
Expand Down Expand Up @@ -366,21 +390,27 @@ static long[] tokenIds(final String[] tokens, final Map<String, Integer> vocab)
* @param input An array of values.
* @return The output array.
*/
private double[] softmax(final float[] input) {
static double[] softmax(final float[] input) {

// Subtract the maximum before exponentiating (numerically stable softmax): exp() of a
// large logit otherwise overflows to +Infinity, yielding NaN scores. Mathematically
// identical to the naive form. Results are kept in double precision throughout.
double max = Double.NEGATIVE_INFINITY;
for (final float value : input) {
max = Math.max(max, value);
}

final double[] t = new double[input.length];
double sum = 0.0;

for (int x = 0; x < input.length; x++) {
double val = Math.exp(input[x]);
final double val = Math.exp(input[x] - max);
sum += val;
t[x] = val;
}

final double[] output = new double[input.length];

for (int x = 0; x < output.length; x++) {
output[x] = (float) (t[x] / sum);
output[x] = t[x] / sum;
}

return output;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
package opennlp.dl.doccat;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.jupiter.api.Test;

import opennlp.tools.tokenize.WordpieceTokenizer;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand Down Expand Up @@ -57,4 +60,83 @@ void testTokenIdsRejectsTokensMissingFromVocabulary() {
assertTrue(e.getMessage().contains("missing"),
"the error message should name the missing token: " + e.getMessage());
}

@Test
void testSoftmaxIsUniformForEqualLogitsAndSumsToOne() {
  // Three identical logits must be mapped to three identical probabilities.
  final float[] logits = {0f, 0f, 0f};
  final double[] probabilities = DocumentCategorizerDL.softmax(logits);
  final double uniform = 1.0 / 3.0;

  assertEquals(3, probabilities.length);
  for (int i = 0; i < probabilities.length; i++) {
    assertEquals(uniform, probabilities[i], 1e-12);
  }

  // The probabilities must form a distribution, i.e. add up to one.
  final double total = probabilities[0] + probabilities[1] + probabilities[2];
  assertEquals(1.0, total, 1e-12);
}

@Test
void testSoftmaxIsNumericallyStableForLargeLogits() {
  // exp(1000) is not representable as a finite double, so a naive exp(logit) softmax would
  // produce Infinity / Infinity = NaN here. The result must be finite and uniform instead.
  final float[] logits = {1000f, 1000f, 1000f};
  final double[] probabilities = DocumentCategorizerDL.softmax(logits);

  double total = 0.0;
  for (int i = 0; i < probabilities.length; i++) {
    final double probability = probabilities[i];
    assertFalse(!Double.isFinite(probability), "softmax must stay finite for large logits");
    assertEquals(1.0 / 3.0, probability, 1e-9);
    total += probability;
  }
  assertEquals(1.0, total, 1e-12);
}

@Test
void testSoftmaxMatchesReferenceDistribution() {
  // Reference values (numpy) for softmax([1, 2, 3]).
  final double[] expected = {0.09003057, 0.24472847, 0.66524096};

  final double[] actual = DocumentCategorizerDL.softmax(new float[] {1f, 2f, 3f});

  for (int i = 0; i < expected.length; i++) {
    assertEquals(expected[i], actual[i], 1e-6);
  }
}

@Test
void testChunkRangesSplitsWithOverlap() {
  // Splitting 210 words into 200-word chunks that overlap by 50 words gives two chunks:
  // the first covers [0,200) and the second starts 50 words earlier, at 150, up to 210.
  final int[] firstExpected = {0, 200};
  final int[] secondExpected = {150, 210};

  final List<int[]> chunks = DocumentCategorizerDL.chunkRanges(210, 200, 50);

  assertEquals(2, chunks.size());
  assertArrayEquals(firstExpected, chunks.get(0));
  assertArrayEquals(secondExpected, chunks.get(1));
}

@Test
void testChunkRangesSingleChunkWhenShorterThanSplit() {
  // An input shorter than the split size must be returned as one chunk covering all of it.
  final int[] wholeInput = {0, 30};

  final List<int[]> chunks = DocumentCategorizerDL.chunkRanges(30, 200, 50);

  assertEquals(1, chunks.size());
  assertArrayEquals(wholeInput, chunks.get(0));
}

@Test
void testChunkRangesEmptyForZeroLength() {
  // No words in, no chunks out.
  final List<int[]> chunks = DocumentCategorizerDL.chunkRanges(0, 200, 50);

  assertTrue(chunks.isEmpty());
}

@Test
void testChunkRangesAlwaysProgressesForInvalidOverlap() {
  // Each row is {length, splitSize, overlapSize}. Without a forward-progress guard an overlap
  // equal to the split size would stall forever, and a larger one would make the start index
  // negative.
  final int[][] configurations = {{10, 5, 5}, {8, 3, 10}, {7, 4, 100}};

  for (int c = 0; c < configurations.length; c++) {
    final int[] configuration = configurations[c];
    final int wordCount = configuration[0];
    final List<int[]> chunks =
        DocumentCategorizerDL.chunkRanges(wordCount, configuration[1], configuration[2]);

    int lastStart = -1;
    for (int i = 0; i < chunks.size(); i++) {
      final int[] chunk = chunks.get(i);
      assertTrue(chunk[0] >= 0, "start must never be negative: " + chunk[0]);
      assertTrue(chunk[1] >= chunk[0], "end must be >= start");
      assertTrue(chunk[0] > lastStart, "each chunk must advance the start index");
      lastStart = chunk[0];
    }

    final int[] finalChunk = chunks.get(chunks.size() - 1);
    assertEquals(wordCount, finalChunk[1], "last chunk must reach the end");
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,34 @@ public void categorize() throws Exception {

}

@Test
public void categorizeReturnsSizedArrayOnFailure() throws Exception {

  final File modelFile = new File(getOpennlpDataDir(),
      "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.onnx");
  final File vocabularyFile = new File(getOpennlpDataDir(),
      "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.vocab");

  try (final DocumentCategorizerDL categorizer =
           new DocumentCategorizerDL(modelFile, vocabularyFile, getCategories(),
               new AverageClassificationScoringStrategy(), new InferenceOptions())) {

    // An empty input array is expected to send categorize() down its failure path before any
    // inference runs. The result must then be all zeros with one entry per category rather
    // than an empty array.
    final double[] fallbackScores = categorizer.categorize(new String[0]);
    Assertions.assertEquals(getCategories().size(), fallbackScores.length);
    for (int i = 0; i < fallbackScores.length; i++) {
      Assertions.assertEquals(0.0, fallbackScores[i]);
    }

    // Methods built on top of categorize() must remain callable on that fallback result
    // instead of indexing into an empty array.
    final int scoreMapSize = categorizer.scoreMap(new String[0]).size();
    Assertions.assertEquals(getCategories().size(), scoreMapSize);
  }

}

@Test
public void categorizeWithAutomaticLabels() throws Exception {

Expand Down
Loading