add query compressing
commit b9f2d78a4f (parent 8e869de449)
README.md

@@ -144,6 +144,10 @@ docker run -d \
 
 Advanced RAG
 
+* Query compression √
+* Query routing
+* Re-rank
+
 Add search engines (Bing, Baidu)
 
 ## Screenshots
BeanConfig.java

@@ -131,4 +131,10 @@ public class BeanConfig {
         return ragService;
     }
 
+    @Bean(name = "queryCompressingRagService")
+    public RAGService queryCompressingRagService() {
+        RAGService ragService = new RAGService("adi_advanced_rag_query_embedding", dataBaseUrl, dataBaseUserName, dataBasePassword);
+        ragService.init();
+        return ragService;
+    }
 }
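The new bean builds a second RAGService aimed at its own embedding table (adi_advanced_rag_query_embedding). The constructor arguments dataBaseUrl, dataBaseUserName and dataBasePassword are not defined in this hunk; the fragment below is only a plausible sketch of where they could come from, assuming Spring's @Value injection (org.springframework.beans.factory.annotation.Value) — the property keys are hypothetical, only the field names appear in the commit.

// Assumed surrounding fields in BeanConfig, not shown in this commit.
@Value("${adi.datasource.url}")       // hypothetical property key
private String dataBaseUrl;

@Value("${adi.datasource.username}")  // hypothetical property key
private String dataBaseUserName;

@Value("${adi.datasource.password}")  // hypothetical property key
private String dataBasePassword;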
AbstractLLMService.java

@@ -1,6 +1,7 @@
 package com.moyz.adi.common.interfaces;
 
 import com.moyz.adi.common.exception.BaseException;
+import com.moyz.adi.common.service.RAGService;
 import com.moyz.adi.common.util.JsonUtil;
 import com.moyz.adi.common.util.LocalCache;
 import com.moyz.adi.common.util.MapDBChatMemoryStore;
@@ -15,6 +16,10 @@ import dev.langchain4j.memory.chat.MessageWindowChatMemory;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.rag.DefaultRetrievalAugmentor;
+import dev.langchain4j.rag.RetrievalAugmentor;
+import dev.langchain4j.rag.query.transformer.CompressingQueryTransformer;
+import dev.langchain4j.rag.query.transformer.QueryTransformer;
 import dev.langchain4j.service.AiServices;
 import dev.langchain4j.service.TokenStream;
 import lombok.extern.slf4j.Slf4j;
@@ -43,7 +48,7 @@ public abstract class AbstractLLMService<T> {
 
     private IChatAssistantWithoutMemory chatAssistantWithoutMemory;
 
-    private MapDBChatMemoryStore mapDBChatMemoryStore;
+    private RAGService queryCompressingRagService;
 
     public AbstractLLMService(String modelName, String settingName, Class<T> clazz) {
         this.modelName = modelName;
@@ -56,6 +61,11 @@ public abstract class AbstractLLMService<T> {
         return this;
     }
 
+    public AbstractLLMService setQueryCompressingRAGService(RAGService ragService){
+        queryCompressingRagService = ragService;
+        return this;
+    }
+
     /**
      * Check whether this service is available (it is usually unavailable because no API key has been configured)
      *
@@ -99,6 +109,15 @@ public abstract class AbstractLLMService<T> {
             throw new BaseException(B_LLM_SERVICE_DISABLED);
         }
         log.info("sseChat,messageId:{}", params.getMessageId());
+
+        //Query compressing
+        QueryTransformer queryTransformer = new CompressingQueryTransformer(getChatLLM());
+
+        RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
+                .queryTransformer(queryTransformer)
+                .contentRetriever(queryCompressingRagService.buildContentRetriever())
+                .build();
+
         //create chat assistant
         if (null == chatAssistant && StringUtils.isNotBlank(params.getMessageId())) {
             ChatMemoryProvider chatMemoryProvider = memoryId -> MessageWindowChatMemory.builder()
@@ -109,6 +128,7 @@ public abstract class AbstractLLMService<T> {
             chatAssistant = AiServices.builder(IChatAssistant.class)
                     .streamingChatLanguageModel(getStreamingChatLLM())
                     .chatMemoryProvider(chatMemoryProvider)
+                    .retrievalAugmentor(retrievalAugmentor)
                     .build();
         } else if (null == chatAssistantWithoutMemory && StringUtils.isBlank(params.getMessageId())) {
             chatAssistantWithoutMemory = AiServices.builder(IChatAssistantWithoutMemory.class)
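Taken together, the two hunks above wire query compression into the streaming assistant: CompressingQueryTransformer asks the chat model to rewrite the latest user question, together with the conversation held in chat memory, into one self-contained query before it reaches the content retriever. The sketch below shows the same wiring in a minimal non-streaming form; the Assistant interface and the chatModel/retriever parameters are illustrative assumptions, not names from this repository.

import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.transformer.CompressingQueryTransformer;
import dev.langchain4j.service.AiServices;

public class QueryCompressionSketch {

    // Hypothetical assistant interface used only for this sketch.
    interface Assistant {
        String chat(String userMessage);
    }

    static Assistant buildAssistant(ChatLanguageModel chatModel, ContentRetriever retriever) {
        // CompressingQueryTransformer asks the chat model to merge the chat history and the
        // latest user message into one standalone query before retrieval.
        RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
                .queryTransformer(new CompressingQueryTransformer(chatModel))
                .contentRetriever(retriever)
                .build();

        // Same builder calls as the commit, but with a blocking model and a single shared memory.
        return AiServices.builder(Assistant.class)
                .chatLanguageModel(chatModel)
                .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
                .retrievalAugmentor(augmentor)
                .build();
    }
}

With chat memory attached, a follow-up such as "How much does it cost?" is first expanded with the entities mentioned in earlier turns, so the embedding search runs on a complete question rather than a fragment.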
Initializer.java

@@ -32,7 +32,7 @@ public class Initializer {
     private SysConfigService sysConfigService;
 
     @Resource
-    private RAGService ragService;
+    private RAGService queryCompressingRagService;
 
     @PostConstruct
     public void init() {
@@ -49,7 +49,7 @@ public class Initializer {
             log.warn("openai service is disabled");
         }
         for (String model : openaiModels) {
-            LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy));
+            LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy).setQueryCompressingRAGService(queryCompressingRagService));
         }
 
         //dashscope
@@ -58,7 +58,7 @@ public class Initializer {
             log.warn("dashscope service is disabled");
         }
         for (String model : dashscopeModels) {
-            LLMContext.addLLMService(model, new DashScopeLLMService(model));
+            LLMContext.addLLMService(model, new DashScopeLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
         }
 
         //qianfan
@@ -67,7 +67,7 @@ public class Initializer {
             log.warn("qianfan service is disabled");
         }
         for (String model : qianfanModels) {
-            LLMContext.addLLMService(model, new QianFanLLMService(model));
+            LLMContext.addLLMService(model, new QianFanLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
         }
 
         //ollama
@@ -76,16 +76,13 @@ public class Initializer {
             log.warn("ollama service is disabled");
         }
         for (String model : ollamaModels) {
-            LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model));
+            LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
         }
 
         ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2).setProxy(proxy));
 
 
         //search engine
         SearchEngineContext.addEngine(AdiConstant.SearchEngineName.GOOGLE, new GoogleSearchEngine().setProxy(proxy));
-
-
-        ragService.init();
     }
 }
RAGService.java

@@ -13,6 +13,8 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.input.Prompt;
 import dev.langchain4j.model.openai.OpenAiTokenizer;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.rag.content.retriever.ContentRetriever;
+import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
 import dev.langchain4j.store.embedding.EmbeddingMatch;
 import dev.langchain4j.store.embedding.EmbeddingStore;
 import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
@@ -32,6 +34,10 @@ import static java.util.stream.Collectors.joining;
 @Slf4j
 public class RAGService {
 
+    private static final int MAX_RESULTS = 3;
+
+    private static final double MIN_SCORE = 0.6;
+
     private String dataBaseUrl;
 
     private String dataBaseUserName;
@@ -91,23 +97,35 @@ public class RAGService {
         return embeddingStore;
     }
 
-    private EmbeddingStoreIngestor getEmbeddingStoreIngestor() {
+    /**
+     * Split the document into chunks, embed them and store the embeddings in the database
+     *
+     * @param document knowledge base document
+     */
+    public void ingest(Document document) {
         DocumentSplitter documentSplitter = DocumentSplitters.recursive(1000, 0, new OpenAiTokenizer(GPT_3_5_TURBO));
         EmbeddingStoreIngestor embeddingStoreIngestor = EmbeddingStoreIngestor.builder()
                 .documentSplitter(documentSplitter)
                 .embeddingModel(embeddingModel)
                 .embeddingStore(embeddingStore)
                 .build();
-        return embeddingStoreIngestor;
+        embeddingStoreIngestor.ingest(document);
     }
 
     /**
-     * Split the document into chunks and embed them
+     * There are two methods for retrieving documents:
+     * 1. ContentRetriever.retrieve()
+     * 2. retrieveAndCreatePrompt()
      *
-     * @param document knowledge base document
+     * @return ContentRetriever
      */
-    public void ingest(Document document) {
-        getEmbeddingStoreIngestor().ingest(document);
+    public ContentRetriever buildContentRetriever() {
+        return EmbeddingStoreContentRetriever.builder()
+                .embeddingStore(embeddingStore)
+                .embeddingModel(embeddingModel)
+                .maxResults(MAX_RESULTS)
+                .minScore(MIN_SCORE)
+                .build();
     }
 
     /**
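buildContentRetriever() exposes the embedding store as a LangChain4j ContentRetriever that returns at most MAX_RESULTS (3) segments with a minimum relevance score of MIN_SCORE (0.6); this is the retriever AbstractLLMService plugs into DefaultRetrievalAugmentor above. Below is a minimal sketch of exercising it directly, outside the augmentor; the sample text, the demo method and the variable names are illustrative only.

import com.moyz.adi.common.service.RAGService;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;

import java.util.List;

public class ContentRetrieverSketch {

    static void demo(RAGService ragService) {
        // ingest(): split, embed and store the document in the database-backed embedding store.
        ragService.ingest(Document.from("LangChain4j is a Java library for building LLM applications."));

        // Retrieve at most MAX_RESULTS segments whose relevance score is >= MIN_SCORE.
        ContentRetriever retriever = ragService.buildContentRetriever();
        List<Content> contents = retriever.retrieve(Query.from("What is LangChain4j?"));
        contents.forEach(content -> System.out.println(content.textSegment().text()));
    }
}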