From b9f2d78a4fb2a1b69d84f6f0aad0dc9379300242 Mon Sep 17 00:00:00 2001 From: moyangzhan Date: Sat, 27 Apr 2024 08:52:07 +0800 Subject: [PATCH] add query compressing --- README.md | 4 +++ .../moyz/adi/common/config/BeanConfig.java | 6 ++++ .../common/interfaces/AbstractLLMService.java | 22 +++++++++++++- .../moyz/adi/common/service/Initializer.java | 13 ++++---- .../moyz/adi/common/service/RAGService.java | 30 +++++++++++++++---- 5 files changed, 60 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 75dca82..8ee93ba 100644 --- a/README.md +++ b/README.md @@ -144,6 +144,10 @@ docker run -d \ 高级RAG +* 查询压缩 √ +* 查询路由 +* Re-rank + 增加搜索引擎(BING、百度) ## 截图 diff --git a/adi-common/src/main/java/com/moyz/adi/common/config/BeanConfig.java b/adi-common/src/main/java/com/moyz/adi/common/config/BeanConfig.java index 6a5b8a1..7591f6d 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/config/BeanConfig.java +++ b/adi-common/src/main/java/com/moyz/adi/common/config/BeanConfig.java @@ -131,4 +131,10 @@ public class BeanConfig { return ragService; } + @Bean(name = "queryCompressingRagService") + public RAGService queryCompressingRagService() { + RAGService ragService = new RAGService("adi_advanced_rag_query_embedding", dataBaseUrl, dataBaseUserName, dataBasePassword); + ragService.init(); + return ragService; + } } diff --git a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java index 8842491..852c1ea 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java @@ -1,6 +1,7 @@ package com.moyz.adi.common.interfaces; import com.moyz.adi.common.exception.BaseException; +import com.moyz.adi.common.service.RAGService; import com.moyz.adi.common.util.JsonUtil; import com.moyz.adi.common.util.LocalCache; import com.moyz.adi.common.util.MapDBChatMemoryStore; @@ -15,6 +16,10 @@ import dev.langchain4j.memory.chat.MessageWindowChatMemory; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.output.Response; +import dev.langchain4j.rag.DefaultRetrievalAugmentor; +import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.rag.query.transformer.CompressingQueryTransformer; +import dev.langchain4j.rag.query.transformer.QueryTransformer; import dev.langchain4j.service.AiServices; import dev.langchain4j.service.TokenStream; import lombok.extern.slf4j.Slf4j; @@ -43,7 +48,7 @@ public abstract class AbstractLLMService { private IChatAssistantWithoutMemory chatAssistantWithoutMemory; - private MapDBChatMemoryStore mapDBChatMemoryStore; + private RAGService queryCompressingRagService; public AbstractLLMService(String modelName, String settingName, Class clazz) { this.modelName = modelName; @@ -56,6 +61,11 @@ public abstract class AbstractLLMService { return this; } + public AbstractLLMService setQueryCompressingRAGService(RAGService ragService){ + queryCompressingRagService = ragService; + return this; + } + /** * 检测该service是否可用(不可用的情况通常是没有配置key) * @@ -99,6 +109,15 @@ public abstract class AbstractLLMService { throw new BaseException(B_LLM_SERVICE_DISABLED); } log.info("sseChat,messageId:{}", params.getMessageId()); + + //Query compressing + QueryTransformer queryTransformer = new CompressingQueryTransformer(getChatLLM()); + + RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder() + .queryTransformer(queryTransformer) + .contentRetriever(queryCompressingRagService.buildContentRetriever()) + .build(); + //create chat assistant if (null == chatAssistant && StringUtils.isNotBlank(params.getMessageId())) { ChatMemoryProvider chatMemoryProvider = memoryId -> MessageWindowChatMemory.builder() @@ -109,6 +128,7 @@ public abstract class AbstractLLMService { chatAssistant = AiServices.builder(IChatAssistant.class) .streamingChatLanguageModel(getStreamingChatLLM()) .chatMemoryProvider(chatMemoryProvider) + .retrievalAugmentor(retrievalAugmentor) .build(); } else if (null == chatAssistantWithoutMemory && StringUtils.isBlank(params.getMessageId())) { chatAssistantWithoutMemory = AiServices.builder(IChatAssistantWithoutMemory.class) diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/Initializer.java b/adi-common/src/main/java/com/moyz/adi/common/service/Initializer.java index 8fad37a..2005041 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/Initializer.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/Initializer.java @@ -32,7 +32,7 @@ public class Initializer { private SysConfigService sysConfigService; @Resource - private RAGService ragService; + private RAGService queryCompressingRagService; @PostConstruct public void init() { @@ -49,7 +49,7 @@ public class Initializer { log.warn("openai service is disabled"); } for (String model : openaiModels) { - LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy)); + LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy).setQueryCompressingRAGService(queryCompressingRagService)); } //dashscope @@ -58,7 +58,7 @@ public class Initializer { log.warn("dashscope service is disabled"); } for (String model : dashscopeModels) { - LLMContext.addLLMService(model, new DashScopeLLMService(model)); + LLMContext.addLLMService(model, new DashScopeLLMService(model).setQueryCompressingRAGService(queryCompressingRagService)); } //qianfan @@ -67,7 +67,7 @@ public class Initializer { log.warn("qianfan service is disabled"); } for (String model : qianfanModels) { - LLMContext.addLLMService(model, new QianFanLLMService(model)); + LLMContext.addLLMService(model, new QianFanLLMService(model).setQueryCompressingRAGService(queryCompressingRagService)); } //ollama @@ -76,16 +76,13 @@ public class Initializer { log.warn("ollama service is disabled"); } for (String model : ollamaModels) { - LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model)); + LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model).setQueryCompressingRAGService(queryCompressingRagService)); } ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2).setProxy(proxy)); - //search engine SearchEngineContext.addEngine(AdiConstant.SearchEngineName.GOOGLE, new GoogleSearchEngine().setProxy(proxy)); - - ragService.init(); } } diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java b/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java index 70d8b6a..6cd816d 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java @@ -13,6 +13,8 @@ import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.openai.OpenAiTokenizer; import dev.langchain4j.model.output.Response; +import dev.langchain4j.rag.content.retriever.ContentRetriever; +import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreIngestor; @@ -32,6 +34,10 @@ import static java.util.stream.Collectors.joining; @Slf4j public class RAGService { + + private static final int MAX_RESULTS = 3; + private static final double MIN_SCORE = 0.6; + private String dataBaseUrl; private String dataBaseUserName; @@ -91,23 +97,35 @@ public class RAGService { return embeddingStore; } - private EmbeddingStoreIngestor getEmbeddingStoreIngestor() { + /** + * 对文档切块、向量化并存储到数据库 + * + * @param document 知识库文档 + */ + public void ingest(Document document) { DocumentSplitter documentSplitter = DocumentSplitters.recursive(1000, 0, new OpenAiTokenizer(GPT_3_5_TURBO)); EmbeddingStoreIngestor embeddingStoreIngestor = EmbeddingStoreIngestor.builder() .documentSplitter(documentSplitter) .embeddingModel(embeddingModel) .embeddingStore(embeddingStore) .build(); - return embeddingStoreIngestor; + embeddingStoreIngestor.ingest(document); } /** - * 对文档切块并向量化 + * There are two methods for retrieve documents: + * 1. ContentRetriever.retrieve() + * 2. retrieveAndCreatePrompt() * - * @param document 知识库文档 + * @return ContentRetriever */ - public void ingest(Document document) { - getEmbeddingStoreIngestor().ingest(document); + public ContentRetriever buildContentRetriever() { + return EmbeddingStoreContentRetriever.builder() + .embeddingStore(embeddingStore) + .embeddingModel(embeddingModel) + .maxResults(MAX_RESULTS) + .minScore(MIN_SCORE) + .build(); } /**