Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import javax.validation.constraints.NotNull;
import java.text.BreakIterator;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

Expand Down Expand Up @@ -60,34 +61,51 @@ public void run() {

final String cleanContent = String.join(SPACE, this.content.trim().split("\\s+"));
final int splitAtTokens = embeddingsAPI.config.getConfigInteger(AppKeys.EMBEDDINGS_SPLIT_AT_TOKENS);
final int overlapTokens = 50;

// split into sentences
// split into sentences, carrying a 50-token overlap into each new chunk
final BreakIterator iterator = BreakIterator.getSentenceInstance(Locale.getDefault());
final StringBuilder buffer = new StringBuilder();
iterator.setText(cleanContent);
int start = iterator.first();

final List<String> sentences = new ArrayList<>();
final List<Integer> tokenCounts = new ArrayList<>();
int totalTokens = 0;
int start = iterator.first();

for (int end = iterator.next(); end != BreakIterator.DONE; start = end, end = iterator.next()) {
final String sentence = cleanContent.substring(start, end);
final String sentence = cleanContent.substring(start, end).trim();
final int tokenCount = EncodingUtil.get()
.getEncoding()
.map(encoding -> encoding.countTokens(sentence))
.orElse(0);
sentences.add(sentence);
tokenCounts.add(tokenCount);
totalTokens += tokenCount;

if (totalTokens < splitAtTokens) {
buffer.append(sentence.trim()).append(SPACE);
} else {
saveEmbedding(buffer.toString());
buffer.setLength(0);
buffer.append(sentence.trim()).append(SPACE);
totalTokens = tokenCount;
if (totalTokens >= splitAtTokens) {
saveEmbedding(String.join(SPACE, sentences));

// retain trailing sentences totalling ~overlapTokens for the next chunk
int overlapStart = sentences.size();
int overlapCount = 0;
while (overlapStart > 0 && overlapCount < overlapTokens) {
overlapStart--;
overlapCount += tokenCounts.get(overlapStart);
}

final List<String> overlap = new ArrayList<>(sentences.subList(overlapStart, sentences.size()));
final List<Integer> overlapCounts = new ArrayList<>(tokenCounts.subList(overlapStart, tokenCounts.size()));
sentences.clear();
tokenCounts.clear();
sentences.addAll(overlap);
tokenCounts.addAll(overlapCounts);
totalTokens = overlapCount;
}
}

if (buffer.toString().split("\\s+").length > 0) {
if (!sentences.isEmpty()) {
AppConfig.debugLogger(embeddingsAPI.config, this.getClass(), () -> String.format("Saving embeddings for contentlet ID '%s'", this.contentlet.getIdentifier()));
this.saveEmbedding(buffer.toString());
this.saveEmbedding(String.join(SPACE, sentences));
AppConfig.debugLogger(embeddingsAPI.config, this.getClass(), () -> String.format("Embeddings for contentlet ID '%s' were saved", this.contentlet.getIdentifier()));
}
} catch (final Exception e) {
Expand Down
Loading