+ * Only cosine similarity is used.
+ * Only ivfflat index is used.
+ */
+public class AdiPgVectorEmbeddingStore implements EmbeddingStore {
+
+ private static final Logger log = LoggerFactory.getLogger(AdiPgVectorEmbeddingStore.class);
+
+ private static final Gson GSON = new Gson();
+
+ private final String host;
+ private final Integer port;
+ private final String user;
+ private final String password;
+ private final String database;
+ private final String table;
+
+ /**
+ * All args constructor for PgVectorEmbeddingStore Class
+ *
+ * @param host The database host
+ * @param port The database port
+ * @param user The database user
+ * @param password The database password
+ * @param database The database name
+ * @param table The database table
+ * @param dimension The vector dimension
+ * @param useIndex Should use IVFFlat index
+ * @param indexListSize The IVFFlat number of lists
+ * @param createTable Should create table automatically
+ * @param dropTableFirst Should drop table first, usually for testing
+ */
+ @Builder
+ public AdiPgVectorEmbeddingStore(
+ String host,
+ Integer port,
+ String user,
+ String password,
+ String database,
+ String table,
+ Integer dimension,
+ Boolean useIndex,
+ Integer indexListSize,
+ Boolean createTable,
+ Boolean dropTableFirst) {
+ this.host = ensureNotBlank(host, "host");
+ this.port = ensureGreaterThanZero(port, "port");
+ this.user = ensureNotBlank(user, "user");
+ this.password = ensureNotBlank(password, "password");
+ this.database = ensureNotBlank(database, "database");
+ this.table = ensureNotBlank(table, "table");
+
+ useIndex = getOrDefault(useIndex, false);
+ createTable = getOrDefault(createTable, true);
+ dropTableFirst = getOrDefault(dropTableFirst, false);
+
+ try (Connection connection = setupConnection()) {
+
+ if (dropTableFirst) {
+ connection.createStatement().executeUpdate(String.format("DROP TABLE IF EXISTS %s", table));
+ }
+
+ if (createTable) {
+ connection.createStatement().executeUpdate(String.format(
+ "CREATE TABLE IF NOT EXISTS %s (" +
+ "embedding_id UUID PRIMARY KEY, " +
+ "embedding vector(%s), " +
+ "text TEXT NULL, " +
+ "metadata JSON NULL" +
+ ")",
+ table, ensureGreaterThanZero(dimension, "dimension")));
+ }
+
+ if (useIndex) {
+ final String indexName = table + "_ivfflat_index";
+ connection.createStatement().executeUpdate(String.format(
+ "CREATE INDEX IF NOT EXISTS %s ON %s " +
+ "USING ivfflat (embedding vector_cosine_ops) " +
+ "WITH (lists = %s)",
+ indexName, table, ensureGreaterThanZero(indexListSize, "indexListSize")));
+ }
+ } catch (SQLException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private Connection setupConnection() throws SQLException {
+ Connection connection = DriverManager.getConnection(
+ String.format("jdbc:postgresql://%s:%s/%s", host, port, database),
+ user,
+ password
+ );
+ connection.createStatement().executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
+ PGvector.addVectorType(connection);
+ return connection;
+ }
+
+ /**
+ * Adds a given embedding to the store.
+ *
+ * @param embedding The embedding to be added to the store.
+ * @return The auto-generated ID associated with the added embedding.
+ */
+ @Override
+ public String add(Embedding embedding) {
+ String id = randomUUID();
+ addInternal(id, embedding, null);
+ return id;
+ }
+
+ /**
+ * Adds a given embedding to the store.
+ *
+ * @param id The unique identifier for the embedding to be added.
+ * @param embedding The embedding to be added to the store.
+ */
+ @Override
+ public void add(String id, Embedding embedding) {
+ addInternal(id, embedding, null);
+ }
+
+ /**
+ * Adds a given embedding and the corresponding content that has been embedded to the store.
+ *
+ * @param embedding The embedding to be added to the store.
+ * @param textSegment Original content that was embedded.
+ * @return The auto-generated ID associated with the added embedding.
+ */
+ @Override
+ public String add(Embedding embedding, TextSegment textSegment) {
+ String id = randomUUID();
+ addInternal(id, embedding, textSegment);
+ return id;
+ }
+
+ /**
+ * Adds multiple embeddings to the store.
+ *
+ * @param embeddings A list of embeddings to be added to the store.
+ * @return A list of auto-generated IDs associated with the added embeddings.
+ */
+ @Override
+ public List addAll(List embeddings) {
+ List ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList());
+ addAllInternal(ids, embeddings, null);
+ return ids;
+ }
+
+ /**
+ * Adds multiple embeddings and their corresponding contents that have been embedded to the store.
+ *
+ * @param embeddings A list of embeddings to be added to the store.
+ * @param embedded A list of original contents that were embedded.
+ * @return A list of auto-generated IDs associated with the added embeddings.
+ */
+ @Override
+ public List addAll(List embeddings, List embedded) {
+ List ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList());
+ addAllInternal(ids, embeddings, embedded);
+ return ids;
+ }
+
+ /**
+ * Finds the most relevant (closest in space) embeddings to the provided reference embedding.
+ *
+ * @param referenceEmbedding The embedding used as a reference. Returned embeddings should be relevant (closest) to this one.
+ * @param maxResults The maximum number of embeddings to be returned.
+ * @param minScore The minimum relevance score, ranging from 0 to 1 (inclusive).
+ * Only embeddings with a score of this value or higher will be returned.
+ * @return A list of embedding matches.
+ * Each embedding match includes a relevance score (derivative of cosine distance),
+ * ranging from 0 (not relevant) to 1 (highly relevant).
+ */
+ @Override
+ public List> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
+ List> result = new ArrayList<>();
+ try (Connection connection = setupConnection()) {
+ String referenceVector = Arrays.toString(referenceEmbedding.vector());
+ String query = String.format(
+ "WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s) SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;",
+ referenceVector, table, minScore, maxResults);
+ PreparedStatement selectStmt = connection.prepareStatement(query);
+
+ ResultSet resultSet = selectStmt.executeQuery();
+ while (resultSet.next()) {
+ double score = resultSet.getDouble("score");
+ String embeddingId = resultSet.getString("embedding_id");
+
+ PGvector vector = (PGvector) resultSet.getObject("embedding");
+ Embedding embedding = new Embedding(vector.toArray());
+
+ String text = resultSet.getString("text");
+ TextSegment textSegment = null;
+ if (isNotNullOrBlank(text)) {
+ String metadataJson = Optional.ofNullable(resultSet.getString("metadata")).orElse("{}");
+ Type type = new TypeToken