diff --git a/.gitignore b/.gitignore index 549e00a..d056752 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ target/ !.mvn/wrapper/maven-wrapper.jar !**/src/main/**/target/ !**/src/test/**/target/ +*.db ### STS ### .apt_generated diff --git a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java index 5611841..be72e0c 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java @@ -3,12 +3,15 @@ package com.moyz.adi.common.interfaces; import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.util.JsonUtil; import com.moyz.adi.common.util.LocalCache; +import com.moyz.adi.common.util.MapDBChatMemoryStore; import com.moyz.adi.common.vo.AnswerMeta; import com.moyz.adi.common.vo.ChatMeta; import com.moyz.adi.common.vo.PromptMeta; import com.moyz.adi.common.vo.SseAskParams; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.output.Response; @@ -36,6 +39,12 @@ public abstract class AbstractLLMService { protected StreamingChatLanguageModel streamingChatLanguageModel; protected ChatLanguageModel chatLanguageModel; + private IChatAssistant chatAssistant; + + private IChatAssistantWithoutMemory chatAssistantWithoutMemory; + + private MapDBChatMemoryStore mapDBChatMemoryStore; + public AbstractLLMService(String modelName, String settingName, Class clazz) { this.modelName = modelName; String st = LocalCache.CONFIGS.get(settingName); @@ -89,19 +98,33 @@ public abstract class AbstractLLMService { log.error("llm service is disabled"); throw new BaseException(B_LLM_SERVICE_DISABLED); } + log.info("sseChat,messageId:{}", params.getMessageId()); //create chat assistant - AiServices serviceBuilder = AiServices.builder(IChatAssistant.class) - .streamingChatLanguageModel(getStreamingChatLLM()); - if (null != params.getChatMemory()) { - serviceBuilder.chatMemory(params.getChatMemory()); + if (null == chatAssistant && StringUtils.isNotBlank(params.getMessageId())) { + ChatMemoryProvider chatMemoryProvider = memoryId -> MessageWindowChatMemory.builder() + .id(memoryId) + .maxMessages(6 + 1) + .chatMemoryStore(MapDBChatMemoryStore.getSingleton()) + .build(); + chatAssistant = AiServices.builder(IChatAssistant.class) + .streamingChatLanguageModel(getStreamingChatLLM()) + .chatMemoryProvider(chatMemoryProvider) + .build(); + } else if (null == chatAssistantWithoutMemory && StringUtils.isBlank(params.getMessageId())) { + chatAssistantWithoutMemory = AiServices.builder(IChatAssistantWithoutMemory.class) + .streamingChatLanguageModel(getStreamingChatLLM()) + .build(); } - IChatAssistant chatAssistant = serviceBuilder.build(); TokenStream tokenStream; - if (StringUtils.isNotBlank(params.getSystemMessage())) { - tokenStream = chatAssistant.chat(params.getSystemMessage(), params.getUserMessage()); + if (StringUtils.isNotBlank(params.getMessageId()) && StringUtils.isNotBlank(params.getSystemMessage())) { + tokenStream = chatAssistant.chat(params.getMessageId(), params.getSystemMessage(), params.getUserMessage()); + } else if (StringUtils.isNotBlank(params.getMessageId()) && StringUtils.isBlank(params.getSystemMessage())) { + tokenStream = chatAssistant.chat(params.getMessageId(), params.getUserMessage()); + } else if (StringUtils.isBlank(params.getMessageId()) && StringUtils.isNotBlank(params.getSystemMessage())) { + tokenStream = chatAssistantWithoutMemory.chat(params.getSystemMessage(), params.getUserMessage()); } else { - tokenStream = chatAssistant.chat(params.getUserMessage()); + tokenStream = chatAssistantWithoutMemory.chat(params.getUserMessage()); } tokenStream .onNext((content) -> { diff --git a/adi-common/src/main/java/com/moyz/adi/common/interfaces/IChatAssistant.java b/adi-common/src/main/java/com/moyz/adi/common/interfaces/IChatAssistant.java index 77f8457..cdaf604 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/interfaces/IChatAssistant.java +++ b/adi-common/src/main/java/com/moyz/adi/common/interfaces/IChatAssistant.java @@ -1,14 +1,11 @@ package com.moyz.adi.common.interfaces; -import dev.langchain4j.service.SystemMessage; -import dev.langchain4j.service.TokenStream; -import dev.langchain4j.service.UserMessage; -import dev.langchain4j.service.V; +import dev.langchain4j.service.*; public interface IChatAssistant { @SystemMessage("{{sm}}") - TokenStream chat(@V("sm") String systemMessage, @UserMessage String prompt); + TokenStream chat(@MemoryId String memoryId, @V("sm") String systemMessage, @UserMessage String prompt); - TokenStream chat(@UserMessage String prompt); + TokenStream chat(@MemoryId String memoryId, @UserMessage String prompt); } diff --git a/adi-common/src/main/java/com/moyz/adi/common/interfaces/IChatAssistantWithoutMemory.java b/adi-common/src/main/java/com/moyz/adi/common/interfaces/IChatAssistantWithoutMemory.java new file mode 100644 index 0000000..fc69780 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/interfaces/IChatAssistantWithoutMemory.java @@ -0,0 +1,14 @@ +package com.moyz.adi.common.interfaces; + +import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.TokenStream; +import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.V; + +public interface IChatAssistantWithoutMemory { + + @SystemMessage("{{sm}}") + TokenStream chat(@V("sm") String systemMessage, @UserMessage String prompt); + + TokenStream chat(@UserMessage String prompt); +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java b/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java index 28c1278..aa3797d 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java @@ -19,6 +19,7 @@ import com.moyz.adi.common.vo.AnswerMeta; import com.moyz.adi.common.vo.PromptMeta; import com.moyz.adi.common.vo.SseAskParams; import com.theokanning.openai.completion.chat.ChatMessageRole; +import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.memory.ChatMemory; @@ -139,29 +140,33 @@ public class ConversationMessageService extends ServiceImpl 0) { - List historyMsgList = this.lambdaQuery() - .eq(ConversationMessage::getUserId, user.getId()) - .eq(ConversationMessage::getConversationUuid, askReq.getConversationUuid()) - .orderByDesc(ConversationMessage::getConversationId) - .last("limit " + user.getUnderstandContextMsgPairNum() * 2) - .list(); - if (!historyMsgList.isEmpty()) { - ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(1000, new OpenAiTokenizer(GPT_3_5_TURBO)); - historyMsgList.sort(Comparator.comparing(ConversationMessage::getId)); - for (ConversationMessage historyMsg : historyMsgList) { - if (ChatMessageRole.USER.value().equals(historyMsg.getMessageRole())) { - UserMessage userMessage = UserMessage.from(historyMsg.getRemark()); - chatMemory.add(userMessage); - } else if (ChatMessageRole.SYSTEM.value().equals(historyMsg.getMessageRole())) { - SystemMessage userMessage = SystemMessage.from(historyMsg.getRemark()); - chatMemory.add(userMessage); - } - } - sseAskParams.setChatMemory(chatMemory); - } - + if (Boolean.TRUE.equals(conversation.getUnderstandContextEnable())) { + sseAskParams.setMessageId(askReq.getConversationUuid()); } +// List historyMsgList = this.lambdaQuery() +// .eq(ConversationMessage::getUserId, user.getId()) +// .eq(ConversationMessage::getConversationUuid, askReq.getConversationUuid()) +// .orderByDesc(ConversationMessage::getId) +// .last("limit " + user.getUnderstandContextMsgPairNum() * 2) +// .list(); +// if (!historyMsgList.isEmpty()) { +// ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(1000, new OpenAiTokenizer(GPT_3_5_TURBO)); +// historyMsgList.sort(Comparator.comparing(ConversationMessage::getId)); +// for (ConversationMessage historyMsg : historyMsgList) { +// if (ChatMessageRoleEnum.USER.getValue().equals(historyMsg.getMessageRole())) { +// UserMessage userMessage = UserMessage.from(historyMsg.getRemark()); +// chatMemory.add(userMessage); +// } else if (ChatMessageRoleEnum.SYSTEM.getValue().equals(historyMsg.getMessageRole())) { +// SystemMessage systemMessage = SystemMessage.from(historyMsg.getRemark()); +// chatMemory.add(systemMessage); +// }else if (ChatMessageRoleEnum.ASSISTANT.getValue().equals(historyMsg.getMessageRole())) { +// AiMessage aiMessage = AiMessage.from(historyMsg.getRemark()); +// chatMemory.add(aiMessage); +// } +// } +// sseAskParams.setChatMemory(chatMemory); +// } +// } } sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, questionMeta, answerMeta) -> { _this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta); diff --git a/adi-common/src/main/java/com/moyz/adi/common/util/MapDBChatMemoryStore.java b/adi-common/src/main/java/com/moyz/adi/common/util/MapDBChatMemoryStore.java new file mode 100644 index 0000000..a5337a3 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/util/MapDBChatMemoryStore.java @@ -0,0 +1,59 @@ +package com.moyz.adi.common.util; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.store.memory.chat.ChatMemoryStore; +import lombok.extern.slf4j.Slf4j; +import org.mapdb.DB; +import org.mapdb.DBMaker; + +import java.util.List; +import java.util.Map; + +import static dev.langchain4j.data.message.ChatMessageDeserializer.messagesFromJson; +import static dev.langchain4j.data.message.ChatMessageSerializer.messagesToJson; +import static org.mapdb.Serializer.STRING; + +@Slf4j +public class MapDBChatMemoryStore implements ChatMemoryStore { + + public static MapDBChatMemoryStore singleton; + + private final DB db = DBMaker.fileDB("chat-memory.db").transactionEnable().make(); + + private final Map map = db.hashMap("messages", STRING, STRING).createOrOpen(); + + @Override + public List getMessages(Object memoryId) { + String json = map.get((String) memoryId); + return messagesFromJson(json); + } + + @Override + public void updateMessages(Object memoryId, List messages) { + if(messages.size() > 0 && messages.get(0) instanceof AiMessage){ + messages.remove(0); + } + String json = messagesToJson(messages); + log.info("updateMessages,{}", json); + map.put((String) memoryId, json); + db.commit(); + } + + @Override + public void deleteMessages(Object memoryId) { + map.remove((String) memoryId); + db.commit(); + } + + public static MapDBChatMemoryStore getSingleton() { + if (null == singleton) { + synchronized (MapDBChatMemoryStore.class) { + if (null == singleton) { + singleton = new MapDBChatMemoryStore(); + } + } + } + return singleton; + } +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/vo/SseAskParams.java b/adi-common/src/main/java/com/moyz/adi/common/vo/SseAskParams.java index a124c08..c88f20d 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/vo/SseAskParams.java +++ b/adi-common/src/main/java/com/moyz/adi/common/vo/SseAskParams.java @@ -9,6 +9,8 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @Data public class SseAskParams { + private String messageId; + private User user; private String regenerateQuestionUuid; diff --git a/pom.xml b/pom.xml index 5999094..ac05f46 100644 --- a/pom.xml +++ b/pom.xml @@ -130,6 +130,17 @@ velocity-engine-core 2.3 + + org.mapdb + mapdb + 3.0.9 + + + org.jetbrains.kotlin + kotlin-stdlib + + + dev.langchain4j langchain4j