fix: 上下文理解

This commit is contained in:
moyangzhan 2024-04-12 09:19:17 +08:00
parent ed7296fa55
commit 031b7f66e2
8 changed files with 148 additions and 36 deletions

1
.gitignore vendored
View File

@ -3,6 +3,7 @@ target/
!.mvn/wrapper/maven-wrapper.jar
!**/src/main/**/target/
!**/src/test/**/target/
*.db
### STS ###
.apt_generated

View File

@ -3,12 +3,15 @@ package com.moyz.adi.common.interfaces;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.util.MapDBChatMemoryStore;
import com.moyz.adi.common.vo.AnswerMeta;
import com.moyz.adi.common.vo.ChatMeta;
import com.moyz.adi.common.vo.PromptMeta;
import com.moyz.adi.common.vo.SseAskParams;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
@ -36,6 +39,12 @@ public abstract class AbstractLLMService<T> {
protected StreamingChatLanguageModel streamingChatLanguageModel;
protected ChatLanguageModel chatLanguageModel;
private IChatAssistant chatAssistant;
private IChatAssistantWithoutMemory chatAssistantWithoutMemory;
private MapDBChatMemoryStore mapDBChatMemoryStore;
public AbstractLLMService(String modelName, String settingName, Class<T> clazz) {
this.modelName = modelName;
String st = LocalCache.CONFIGS.get(settingName);
@ -89,19 +98,33 @@ public abstract class AbstractLLMService<T> {
log.error("llm service is disabled");
throw new BaseException(B_LLM_SERVICE_DISABLED);
}
log.info("sseChat,messageId:{}", params.getMessageId());
//create chat assistant
AiServices<IChatAssistant> serviceBuilder = AiServices.builder(IChatAssistant.class)
.streamingChatLanguageModel(getStreamingChatLLM());
if (null != params.getChatMemory()) {
serviceBuilder.chatMemory(params.getChatMemory());
if (null == chatAssistant && StringUtils.isNotBlank(params.getMessageId())) {
ChatMemoryProvider chatMemoryProvider = memoryId -> MessageWindowChatMemory.builder()
.id(memoryId)
.maxMessages(6 + 1)
.chatMemoryStore(MapDBChatMemoryStore.getSingleton())
.build();
chatAssistant = AiServices.builder(IChatAssistant.class)
.streamingChatLanguageModel(getStreamingChatLLM())
.chatMemoryProvider(chatMemoryProvider)
.build();
} else if (null == chatAssistantWithoutMemory && StringUtils.isBlank(params.getMessageId())) {
chatAssistantWithoutMemory = AiServices.builder(IChatAssistantWithoutMemory.class)
.streamingChatLanguageModel(getStreamingChatLLM())
.build();
}
IChatAssistant chatAssistant = serviceBuilder.build();
TokenStream tokenStream;
if (StringUtils.isNotBlank(params.getSystemMessage())) {
tokenStream = chatAssistant.chat(params.getSystemMessage(), params.getUserMessage());
if (StringUtils.isNotBlank(params.getMessageId()) && StringUtils.isNotBlank(params.getSystemMessage())) {
tokenStream = chatAssistant.chat(params.getMessageId(), params.getSystemMessage(), params.getUserMessage());
} else if (StringUtils.isNotBlank(params.getMessageId()) && StringUtils.isBlank(params.getSystemMessage())) {
tokenStream = chatAssistant.chat(params.getMessageId(), params.getUserMessage());
} else if (StringUtils.isBlank(params.getMessageId()) && StringUtils.isNotBlank(params.getSystemMessage())) {
tokenStream = chatAssistantWithoutMemory.chat(params.getSystemMessage(), params.getUserMessage());
} else {
tokenStream = chatAssistant.chat(params.getUserMessage());
tokenStream = chatAssistantWithoutMemory.chat(params.getUserMessage());
}
tokenStream
.onNext((content) -> {

View File

@ -1,14 +1,11 @@
package com.moyz.adi.common.interfaces;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.V;
import dev.langchain4j.service.*;
public interface IChatAssistant {
@SystemMessage("{{sm}}")
TokenStream chat(@V("sm") String systemMessage, @UserMessage String prompt);
TokenStream chat(@MemoryId String memoryId, @V("sm") String systemMessage, @UserMessage String prompt);
TokenStream chat(@UserMessage String prompt);
TokenStream chat(@MemoryId String memoryId, @UserMessage String prompt);
}

View File

@ -0,0 +1,14 @@
package com.moyz.adi.common.interfaces;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.V;
/**
 * Streaming chat assistant contract for stateless calls: neither method
 * declares a memory-id parameter, so each call carries only the messages
 * passed in its arguments.
 *
 * <p>NOTE(review): presumably implemented at runtime by langchain4j
 * {@code AiServices} rather than by hand — confirm against the caller.
 */
public interface IChatAssistantWithoutMemory {
/**
 * Sends a user prompt together with a caller-supplied system message.
 *
 * @param systemMessage text substituted for the {@code {{sm}}} placeholder of
 *                      the {@code @SystemMessage} template (bound via {@code @V("sm")})
 * @param prompt        the user message
 * @return a {@link TokenStream} for the model's reply
 */
@SystemMessage("{{sm}}")
TokenStream chat(@V("sm") String systemMessage, @UserMessage String prompt);
/**
 * Sends a user prompt without any system message.
 *
 * @param prompt the user message
 * @return a {@link TokenStream} for the model's reply
 */
TokenStream chat(@UserMessage String prompt);
}

View File

@ -19,6 +19,7 @@ import com.moyz.adi.common.vo.AnswerMeta;
import com.moyz.adi.common.vo.PromptMeta;
import com.moyz.adi.common.vo.SseAskParams;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
@ -139,29 +140,33 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
sseAskParams.setSystemMessage(conversation.getAiSystemMessage());
}
//history message
if (Boolean.TRUE.equals(conversation.getUnderstandContextEnable()) && user.getUnderstandContextMsgPairNum() > 0) {
List<ConversationMessage> historyMsgList = this.lambdaQuery()
.eq(ConversationMessage::getUserId, user.getId())
.eq(ConversationMessage::getConversationUuid, askReq.getConversationUuid())
.orderByDesc(ConversationMessage::getConversationId)
.last("limit " + user.getUnderstandContextMsgPairNum() * 2)
.list();
if (!historyMsgList.isEmpty()) {
ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(1000, new OpenAiTokenizer(GPT_3_5_TURBO));
historyMsgList.sort(Comparator.comparing(ConversationMessage::getId));
for (ConversationMessage historyMsg : historyMsgList) {
if (ChatMessageRole.USER.value().equals(historyMsg.getMessageRole())) {
UserMessage userMessage = UserMessage.from(historyMsg.getRemark());
chatMemory.add(userMessage);
} else if (ChatMessageRole.SYSTEM.value().equals(historyMsg.getMessageRole())) {
SystemMessage userMessage = SystemMessage.from(historyMsg.getRemark());
chatMemory.add(userMessage);
}
}
sseAskParams.setChatMemory(chatMemory);
}
if (Boolean.TRUE.equals(conversation.getUnderstandContextEnable())) {
sseAskParams.setMessageId(askReq.getConversationUuid());
}
// List<ConversationMessage> historyMsgList = this.lambdaQuery()
// .eq(ConversationMessage::getUserId, user.getId())
// .eq(ConversationMessage::getConversationUuid, askReq.getConversationUuid())
// .orderByDesc(ConversationMessage::getId)
// .last("limit " + user.getUnderstandContextMsgPairNum() * 2)
// .list();
// if (!historyMsgList.isEmpty()) {
// ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(1000, new OpenAiTokenizer(GPT_3_5_TURBO));
// historyMsgList.sort(Comparator.comparing(ConversationMessage::getId));
// for (ConversationMessage historyMsg : historyMsgList) {
// if (ChatMessageRoleEnum.USER.getValue().equals(historyMsg.getMessageRole())) {
// UserMessage userMessage = UserMessage.from(historyMsg.getRemark());
// chatMemory.add(userMessage);
// } else if (ChatMessageRoleEnum.SYSTEM.getValue().equals(historyMsg.getMessageRole())) {
// SystemMessage systemMessage = SystemMessage.from(historyMsg.getRemark());
// chatMemory.add(systemMessage);
// }else if (ChatMessageRoleEnum.ASSISTANT.getValue().equals(historyMsg.getMessageRole())) {
// AiMessage aiMessage = AiMessage.from(historyMsg.getRemark());
// chatMemory.add(aiMessage);
// }
// }
// sseAskParams.setChatMemory(chatMemory);
// }
// }
}
sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, questionMeta, answerMeta) -> {
_this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta);

View File

@ -0,0 +1,59 @@
package com.moyz.adi.common.util;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import lombok.extern.slf4j.Slf4j;
import org.mapdb.DB;
import org.mapdb.DBMaker;
import java.util.List;
import java.util.Map;
import static dev.langchain4j.data.message.ChatMessageDeserializer.messagesFromJson;
import static dev.langchain4j.data.message.ChatMessageSerializer.messagesToJson;
import static org.mapdb.Serializer.STRING;
/**
 * {@link ChatMemoryStore} that keeps each conversation's messages as a JSON
 * string in a file-backed MapDB hash map ({@code chat-memory.db}, map name
 * {@code messages}), keyed by the memory id.
 *
 * <p>Every write is followed by {@code db.commit()} because the database is
 * opened with transactions enabled.
 *
 * <p>NOTE(review): this class never closes {@code db}; confirm whether a
 * shutdown hook or explicit close is needed for the deployment.
 */
@Slf4j
public class MapDBChatMemoryStore implements ChatMemoryStore {

    /**
     * Lazily created shared instance; obtain it through {@link #getSingleton()}.
     * Declared {@code volatile} so the double-checked initialization below
     * publishes a fully constructed object to other threads.
     */
    public static volatile MapDBChatMemoryStore singleton;

    // Field order matters: `map` is created from `db`.
    private final DB db = DBMaker.fileDB("chat-memory.db").transactionEnable().make();
    private final Map<String, String> map = db.hashMap("messages", STRING, STRING).createOrOpen();

    /**
     * Loads the messages stored under {@code memoryId}.
     *
     * @param memoryId conversation key; must be a {@link String}
     * @return the messages deserialized from the stored JSON
     * @throws ClassCastException if {@code memoryId} is not a {@code String}
     */
    @Override
    public List<ChatMessage> getMessages(Object memoryId) {
        // NOTE(review): for an unknown id `json` is null and is handed straight to
        // messagesFromJson — confirm it tolerates null in the langchain4j version in use.
        String json = map.get((String) memoryId);
        return messagesFromJson(json);
    }

    /**
     * Replaces the stored messages for {@code memoryId} and commits.
     *
     * <p>If the first message is an {@link AiMessage} it is left out of what is
     * persisted (presumably so stored history never starts with an assistant
     * reply whose question was evicted — confirm). The caller's list is not
     * modified, so unmodifiable lists are accepted.
     *
     * @param memoryId conversation key; must be a {@link String}
     * @param messages full message list to persist
     * @throws ClassCastException if {@code memoryId} is not a {@code String}
     */
    @Override
    public void updateMessages(Object memoryId, List<ChatMessage> messages) {
        List<ChatMessage> toStore = messages;
        if (!messages.isEmpty() && messages.get(0) instanceof AiMessage) {
            // Skip the leading AI message via a view instead of mutating the argument.
            toStore = messages.subList(1, messages.size());
        }
        String json = messagesToJson(toStore);
        log.info("updateMessages,{}", json);
        map.put((String) memoryId, json);
        db.commit();
    }

    /**
     * Removes everything stored under {@code memoryId} and commits.
     *
     * @param memoryId conversation key; must be a {@link String}
     * @throws ClassCastException if {@code memoryId} is not a {@code String}
     */
    @Override
    public void deleteMessages(Object memoryId) {
        map.remove((String) memoryId);
        db.commit();
    }

    /**
     * Returns the shared store, creating it on first use.
     *
     * @return the single shared {@code MapDBChatMemoryStore}
     */
    public static MapDBChatMemoryStore getSingleton() {
        // Read the volatile field once on the fast path.
        MapDBChatMemoryStore instance = singleton;
        if (instance == null) {
            synchronized (MapDBChatMemoryStore.class) {
                instance = singleton;
                if (instance == null) {
                    instance = new MapDBChatMemoryStore();
                    singleton = instance;
                }
            }
        }
        return instance;
    }
}

View File

@ -9,6 +9,8 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@Data
public class SseAskParams {
private String messageId;
private User user;
private String regenerateQuestionUuid;

11
pom.xml
View File

@ -130,6 +130,17 @@
<artifactId>velocity-engine-core</artifactId>
<version>2.3</version>
</dependency>
<dependency>
<groupId>org.mapdb</groupId>
<artifactId>mapdb</artifactId>
<version>3.0.9</version>
<exclusions>
<exclusion>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>