add query compression

This commit is contained in:
moyangzhan 2024-04-27 08:52:07 +08:00
parent 8e869de449
commit b9f2d78a4f
5 changed files with 60 additions and 15 deletions

View File

@ -144,6 +144,10 @@ docker run -d \
高级RAG
* 查询压缩 √
* 查询路由
* Re-rank
增加搜索引擎BING、百度
## 截图

View File

@ -131,4 +131,10 @@ public class BeanConfig {
return ragService;
}
@Bean(name = "queryCompressingRagService")
public RAGService queryCompressingRagService() {
    // Dedicated RAG service for query compression, backed by its own
    // embedding table so compressed-query retrieval is isolated from the
    // main knowledge-base store.
    RAGService service =
            new RAGService("adi_advanced_rag_query_embedding", dataBaseUrl, dataBaseUserName, dataBasePassword);
    service.init();
    return service;
}
}

View File

@ -1,6 +1,7 @@
package com.moyz.adi.common.interfaces;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.service.RAGService;
import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.util.MapDBChatMemoryStore;
@ -15,6 +16,10 @@ import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.query.transformer.CompressingQueryTransformer;
import dev.langchain4j.rag.query.transformer.QueryTransformer;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;
import lombok.extern.slf4j.Slf4j;
@ -43,7 +48,7 @@ public abstract class AbstractLLMService<T> {
private IChatAssistantWithoutMemory chatAssistantWithoutMemory;
private MapDBChatMemoryStore mapDBChatMemoryStore;
private RAGService queryCompressingRagService;
public AbstractLLMService(String modelName, String settingName, Class<T> clazz) {
this.modelName = modelName;
@ -56,6 +61,11 @@ public abstract class AbstractLLMService<T> {
return this;
}
/**
 * Injects the RAG service used for query compression.
 * <p>
 * Fluent setter; returns {@code this} so it can be chained after the
 * constructor (see the model registration in Initializer).
 *
 * @param ragService RAG service whose content retriever backs compressed-query retrieval
 * @return this service instance for chaining
 */
public AbstractLLMService<T> setQueryCompressingRAGService(RAGService ragService) {
    // Parameterized return type (instead of the raw AbstractLLMService)
    // preserves <T> for callers that keep chaining.
    this.queryCompressingRagService = ragService;
    return this;
}
/**
* 检测该service是否可用不可用的情况通常是没有配置key
*
@ -99,6 +109,15 @@ public abstract class AbstractLLMService<T> {
throw new BaseException(B_LLM_SERVICE_DISABLED);
}
log.info("sseChat,messageId:{}", params.getMessageId());
//Query compressing
QueryTransformer queryTransformer = new CompressingQueryTransformer(getChatLLM());
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryTransformer(queryTransformer)
.contentRetriever(queryCompressingRagService.buildContentRetriever())
.build();
//create chat assistant
if (null == chatAssistant && StringUtils.isNotBlank(params.getMessageId())) {
ChatMemoryProvider chatMemoryProvider = memoryId -> MessageWindowChatMemory.builder()
@ -109,6 +128,7 @@ public abstract class AbstractLLMService<T> {
chatAssistant = AiServices.builder(IChatAssistant.class)
.streamingChatLanguageModel(getStreamingChatLLM())
.chatMemoryProvider(chatMemoryProvider)
.retrievalAugmentor(retrievalAugmentor)
.build();
} else if (null == chatAssistantWithoutMemory && StringUtils.isBlank(params.getMessageId())) {
chatAssistantWithoutMemory = AiServices.builder(IChatAssistantWithoutMemory.class)

View File

@ -32,7 +32,7 @@ public class Initializer {
private SysConfigService sysConfigService;
@Resource
private RAGService ragService;
private RAGService queryCompressingRagService;
@PostConstruct
public void init() {
@ -49,7 +49,7 @@ public class Initializer {
log.warn("openai service is disabled");
}
for (String model : openaiModels) {
LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy));
LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy).setQueryCompressingRAGService(queryCompressingRagService));
}
//dashscope
@ -58,7 +58,7 @@ public class Initializer {
log.warn("dashscope service is disabled");
}
for (String model : dashscopeModels) {
LLMContext.addLLMService(model, new DashScopeLLMService(model));
LLMContext.addLLMService(model, new DashScopeLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
}
//qianfan
@ -67,7 +67,7 @@ public class Initializer {
log.warn("qianfan service is disabled");
}
for (String model : qianfanModels) {
LLMContext.addLLMService(model, new QianFanLLMService(model));
LLMContext.addLLMService(model, new QianFanLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
}
//ollama
@ -76,16 +76,13 @@ public class Initializer {
log.warn("ollama service is disabled");
}
for (String model : ollamaModels) {
LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model));
LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
}
ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2).setProxy(proxy));
//search engine
SearchEngineContext.addEngine(AdiConstant.SearchEngineName.GOOGLE, new GoogleSearchEngine().setProxy(proxy));
ragService.init();
}
}

View File

@ -13,6 +13,8 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
@ -32,6 +34,10 @@ import static java.util.stream.Collectors.joining;
@Slf4j
public class RAGService {
private static final int MAX_RESULTS = 3;
private static final double MIN_SCORE = 0.6;
private String dataBaseUrl;
private String dataBaseUserName;
@ -91,23 +97,35 @@ public class RAGService {
return embeddingStore;
}
private EmbeddingStoreIngestor getEmbeddingStoreIngestor() {
/**
* 对文档切块向量化并存储到数据库
*
* @param document 知识库文档
*/
public void ingest(Document document) {
DocumentSplitter documentSplitter = DocumentSplitters.recursive(1000, 0, new OpenAiTokenizer(GPT_3_5_TURBO));
EmbeddingStoreIngestor embeddingStoreIngestor = EmbeddingStoreIngestor.builder()
.documentSplitter(documentSplitter)
.embeddingModel(embeddingModel)
.embeddingStore(embeddingStore)
.build();
return embeddingStoreIngestor;
embeddingStoreIngestor.ingest(document);
}
/**
* 对文档切块并向量化
* There are two methods for retrieve documents:
* 1. ContentRetriever.retrieve()
* 2. retrieveAndCreatePrompt()
*
* @param document 知识库文档
* @return ContentRetriever
*/
public void ingest(Document document) {
getEmbeddingStoreIngestor().ingest(document);
public ContentRetriever buildContentRetriever() {
return EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(MAX_RESULTS)
.minScore(MIN_SCORE)
.build();
}
/**