add query compression

moyangzhan 2024-04-27 08:52:07 +08:00
parent 8e869de449
commit b9f2d78a4f
5 changed files with 60 additions and 15 deletions

View File

@@ -144,6 +144,10 @@ docker run -d \
 Advanced RAG
+* Query compression √
+* Query routing
+* Re-rank
 Added search engines: Bing, Baidu
 ## Screenshots

View File

@@ -131,4 +131,10 @@ public class BeanConfig {
         return ragService;
     }
 
+    @Bean(name = "queryCompressingRagService")
+    public RAGService queryCompressingRagService() {
+        RAGService ragService = new RAGService("adi_advanced_rag_query_embedding", dataBaseUrl, dataBaseUserName, dataBasePassword);
+        ragService.init();
+        return ragService;
+    }
 }

View File

@@ -1,6 +1,7 @@
 package com.moyz.adi.common.interfaces;
 
 import com.moyz.adi.common.exception.BaseException;
+import com.moyz.adi.common.service.RAGService;
 import com.moyz.adi.common.util.JsonUtil;
 import com.moyz.adi.common.util.LocalCache;
 import com.moyz.adi.common.util.MapDBChatMemoryStore;
@@ -15,6 +16,10 @@ import dev.langchain4j.memory.chat.MessageWindowChatMemory;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.rag.DefaultRetrievalAugmentor;
+import dev.langchain4j.rag.RetrievalAugmentor;
+import dev.langchain4j.rag.query.transformer.CompressingQueryTransformer;
+import dev.langchain4j.rag.query.transformer.QueryTransformer;
 import dev.langchain4j.service.AiServices;
 import dev.langchain4j.service.TokenStream;
 import lombok.extern.slf4j.Slf4j;
@@ -43,7 +48,7 @@ public abstract class AbstractLLMService<T> {
     private IChatAssistantWithoutMemory chatAssistantWithoutMemory;
-    private MapDBChatMemoryStore mapDBChatMemoryStore;
+    private RAGService queryCompressingRagService;
 
     public AbstractLLMService(String modelName, String settingName, Class<T> clazz) {
         this.modelName = modelName;
@@ -56,6 +61,11 @@ public abstract class AbstractLLMService<T> {
         return this;
     }
 
+    public AbstractLLMService setQueryCompressingRAGService(RAGService ragService) {
+        queryCompressingRagService = ragService;
+        return this;
+    }
+
     /**
      * Check whether this service is available; it is usually unavailable when no API key is configured
      *
@@ -99,6 +109,15 @@ public abstract class AbstractLLMService<T> {
             throw new BaseException(B_LLM_SERVICE_DISABLED);
         }
         log.info("sseChat,messageId:{}", params.getMessageId());
+
+        //Query compressing
+        QueryTransformer queryTransformer = new CompressingQueryTransformer(getChatLLM());
+        RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
+                .queryTransformer(queryTransformer)
+                .contentRetriever(queryCompressingRagService.buildContentRetriever())
+                .build();
+
         //create chat assistant
         if (null == chatAssistant && StringUtils.isNotBlank(params.getMessageId())) {
             ChatMemoryProvider chatMemoryProvider = memoryId -> MessageWindowChatMemory.builder()
@@ -109,6 +128,7 @@ public abstract class AbstractLLMService<T> {
             chatAssistant = AiServices.builder(IChatAssistant.class)
                     .streamingChatLanguageModel(getStreamingChatLLM())
                     .chatMemoryProvider(chatMemoryProvider)
                     .retrievalAugmentor(retrievalAugmentor)
                     .build();
         } else if (null == chatAssistantWithoutMemory && StringUtils.isBlank(params.getMessageId())) {
             chatAssistantWithoutMemory = AiServices.builder(IChatAssistantWithoutMemory.class)

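Taken together, the changes above route every sseChat call through a query-compression step: langchain4j's CompressingQueryTransformer uses the chat model to rewrite the user's latest question, in light of the conversation history, into a standalone query before the ContentRetriever searches the embedding store. Below is a minimal, self-contained sketch of that wiring, not part of this commit; the Assistant interface and the chatModel/contentRetriever parameters are illustrative placeholders (the project uses IChatAssistant and RAGService#buildContentRetriever instead).

import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.transformer.CompressingQueryTransformer;
import dev.langchain4j.service.AiServices;

class QueryCompressionSketch {

    // Illustrative assistant interface; not the project's IChatAssistant.
    interface Assistant {
        String chat(String userMessage);
    }

    static Assistant build(ChatLanguageModel chatModel, ContentRetriever contentRetriever) {
        // The transformer asks the chat model to merge the chat history and the new
        // question into one standalone query before retrieval happens.
        RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
                .queryTransformer(new CompressingQueryTransformer(chatModel))
                .contentRetriever(contentRetriever)
                .build();

        return AiServices.builder(Assistant.class)
                .chatLanguageModel(chatModel)
                .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
                .retrievalAugmentor(retrievalAugmentor)
                .build();
    }
}

With memory attached, a vague follow-up such as "How do I enable it?" is first compressed into a self-contained question that names the topic discussed earlier, so the embedding-store lookup stays accurate across turns.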
View File

@@ -32,7 +32,7 @@ public class Initializer {
     private SysConfigService sysConfigService;
 
     @Resource
-    private RAGService ragService;
+    private RAGService queryCompressingRagService;
 
     @PostConstruct
     public void init() {
@@ -49,7 +49,7 @@ public class Initializer {
             log.warn("openai service is disabled");
         }
         for (String model : openaiModels) {
-            LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy));
+            LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy).setQueryCompressingRAGService(queryCompressingRagService));
         }
 
         //dashscope
@@ -58,7 +58,7 @@ public class Initializer {
             log.warn("dashscope service is disabled");
         }
         for (String model : dashscopeModels) {
-            LLMContext.addLLMService(model, new DashScopeLLMService(model));
+            LLMContext.addLLMService(model, new DashScopeLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
         }
 
         //qianfan
@@ -67,7 +67,7 @@ public class Initializer {
             log.warn("qianfan service is disabled");
         }
         for (String model : qianfanModels) {
-            LLMContext.addLLMService(model, new QianFanLLMService(model));
+            LLMContext.addLLMService(model, new QianFanLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
         }
 
         //ollama
@@ -76,16 +76,13 @@ public class Initializer {
             log.warn("ollama service is disabled");
         }
         for (String model : ollamaModels) {
-            LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model));
+            LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
         }
 
         ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2).setProxy(proxy));
 
         //search engine
         SearchEngineContext.addEngine(AdiConstant.SearchEngineName.GOOGLE, new GoogleSearchEngine().setProxy(proxy));
-
-        ragService.init();
     }
 }

View File

@@ -13,6 +13,8 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.input.Prompt;
 import dev.langchain4j.model.openai.OpenAiTokenizer;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.rag.content.retriever.ContentRetriever;
+import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
 import dev.langchain4j.store.embedding.EmbeddingMatch;
 import dev.langchain4j.store.embedding.EmbeddingStore;
 import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
@@ -32,6 +34,10 @@ import static java.util.stream.Collectors.joining;
 @Slf4j
 public class RAGService {
 
+    private static final int MAX_RESULTS = 3;
+    private static final double MIN_SCORE = 0.6;
+
     private String dataBaseUrl;
     private String dataBaseUserName;
@@ -91,23 +97,35 @@ public class RAGService {
         return embeddingStore;
     }
 
-    private EmbeddingStoreIngestor getEmbeddingStoreIngestor() {
+    /**
+     * Split the document into chunks, embed them, and store them in the database
+     *
+     * @param document knowledge base document
+     */
+    public void ingest(Document document) {
         DocumentSplitter documentSplitter = DocumentSplitters.recursive(1000, 0, new OpenAiTokenizer(GPT_3_5_TURBO));
         EmbeddingStoreIngestor embeddingStoreIngestor = EmbeddingStoreIngestor.builder()
                 .documentSplitter(documentSplitter)
                 .embeddingModel(embeddingModel)
                 .embeddingStore(embeddingStore)
                 .build();
-        return embeddingStoreIngestor;
+        embeddingStoreIngestor.ingest(document);
     }
 
     /**
-     * Split the document into chunks and embed them
+     * There are two ways to retrieve documents:
+     * 1. ContentRetriever.retrieve()
+     * 2. retrieveAndCreatePrompt()
      *
-     * @param document knowledge base document
+     * @return ContentRetriever
      */
-    public void ingest(Document document) {
-        getEmbeddingStoreIngestor().ingest(document);
+    public ContentRetriever buildContentRetriever() {
+        return EmbeddingStoreContentRetriever.builder()
+                .embeddingStore(embeddingStore)
+                .embeddingModel(embeddingModel)
+                .maxResults(MAX_RESULTS)
+                .minScore(MIN_SCORE)
+                .build();
     }
 
     /**
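The new buildContentRetriever() corresponds to option 1 in the Javadoc above: callers can retrieve matching segments directly instead of going through retrieveAndCreatePrompt(). Below is a small usage sketch, not part of this commit; the ragService argument and the query text are illustrative.

import com.moyz.adi.common.service.RAGService;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;

import java.util.List;

class ContentRetrieverSketch {

    static void printMatches(RAGService ragService) {
        ContentRetriever retriever = ragService.buildContentRetriever();
        // Returns at most MAX_RESULTS (3) segments whose similarity score is at least MIN_SCORE (0.6).
        List<Content> contents = retriever.retrieve(Query.from("What does query compression do?"));
        contents.forEach(content -> System.out.println(content.textSegment().text()));
    }
}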