From 498e9131b99bcb717f2c15de6f3a1b22cb95406e Mon Sep 17 00:00:00 2001 From: moyangzhan Date: Tue, 12 Mar 2024 00:41:49 +0800 Subject: [PATCH] =?UTF-8?q?fix:=E7=9F=A5=E8=AF=86=E5=BA=93token=E8=AE=A1?= =?UTF-8?q?=E7=AE=97=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 18 +-- .../controller/KnowledgeBaseQAController.java | 11 +- .../common/entity/KnowledgeBaseQaRecord.java | 13 ++ .../adi/common/helper/RateLimitHelper.java | 7 ++ .../adi/common/helper/SSEEmitterHelper.java | 95 +++++++++++++++ .../common/interfaces/AbstractLLMService.java | 26 ++-- .../service/ConversationMessageService.java | 112 +++--------------- .../common/service/DashScopeLLMService.java | 5 + .../service/KnowledgeBaseQaRecordService.java | 35 ++++++ .../common/service/KnowledgeBaseService.java | 76 +++++++++--- .../adi/common/service/OpenAiLLMService.java | 14 +++ .../adi/common/service/QianFanLLMService.java | 5 + .../moyz/adi/common/service/RAGService.java | 40 +++++-- .../common/service/UserDayCostService.java | 17 +++ .../java/com/moyz/adi/common/vo/ChatMeta.java | 2 +- .../vo/{QuestionMeta.java => PromptMeta.java} | 2 +- .../com/moyz/adi/common/vo/SseAskParams.java | 1 + docs/create.sql | 19 ++- 18 files changed, 347 insertions(+), 151 deletions(-) create mode 100644 adi-common/src/main/java/com/moyz/adi/common/helper/SSEEmitterHelper.java rename adi-common/src/main/java/com/moyz/adi/common/vo/{QuestionMeta.java => PromptMeta.java} (85%) diff --git a/README.md b/README.md index b2a72c7..cc8641e 100644 --- a/README.md +++ b/README.md @@ -5,11 +5,11 @@ > **该项目如对您有帮助,欢迎点赞** -### 体验网址 +## 体验网址 [http://www.aideepin.com](http://www.aideepin.com/) -### 功能点 +## 功能点 * 注册&登录 * 多会话(多角色) @@ -19,14 +19,14 @@ * 基于大模型的知识库(RAG) * 多模型随意切换 -### 接入的模型: +## 接入的模型: * ChatGPT 3.5 * 通义千问 * 文心一言 * DALL-E 2 -### 技术栈 +## 技术栈 该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web) @@ -44,9 +44,9 @@ springboot3.0.5 vue3+typescript+pnpm -### 如何部署 +## 如何部署 -#### 初始化 +### 初始化 初始化数据库 @@ -65,7 +65,7 @@ update adi_sys_config set value = '{"api_key":"my_dashcope_api_key"}' where name * redis: application-[dev|prod].xml中的spring.data.redis * mail: application.xml中的spring.mail -#### 编译及运行 +### 编译及运行 * 进入项目 @@ -100,12 +100,12 @@ docker run -d \ aideepin:0.0.1 ``` -### 待办: +## 待办: 增强RAG -### 截图 +## 截图 **AI聊天:** ![1691583184761](image/README/1691583184761.png) diff --git a/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseQAController.java b/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseQAController.java index 82c956b..d15fcc7 100644 --- a/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseQAController.java +++ b/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseQAController.java @@ -5,12 +5,15 @@ import com.moyz.adi.common.dto.QAReq; import com.moyz.adi.common.entity.KnowledgeBaseQaRecord; import com.moyz.adi.common.service.KnowledgeBaseQaRecordService; import com.moyz.adi.common.service.KnowledgeBaseService; +import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.annotation.Resource; import jakarta.validation.constraints.Min; import jakarta.validation.constraints.NotNull; +import org.springframework.http.MediaType; import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.annotation.*; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @Tag(name = "知识库问答controller") @RequestMapping("/knowledge-base/qa/") @@ -25,7 +28,13 @@ public class KnowledgeBaseQAController { @PostMapping("/ask/{kbUuid}") public KnowledgeBaseQaRecord ask(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) { - return knowledgeBaseService.answerAndRecord(kbUuid, req.getQuestion(), req.getModelName()); + return knowledgeBaseService.ask(kbUuid, req.getQuestion(), req.getModelName()); + } + + @Operation(summary = "流式响应") + @PostMapping(value = "/process/{kbUuid}", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public SseEmitter sseAsk(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) { + return knowledgeBaseService.sseAsk(kbUuid, req); } @GetMapping("/record/search") diff --git a/adi-common/src/main/java/com/moyz/adi/common/entity/KnowledgeBaseQaRecord.java b/adi-common/src/main/java/com/moyz/adi/common/entity/KnowledgeBaseQaRecord.java index 3c32c83..c27f3ac 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/entity/KnowledgeBaseQaRecord.java +++ b/adi-common/src/main/java/com/moyz/adi/common/entity/KnowledgeBaseQaRecord.java @@ -31,11 +31,24 @@ public class KnowledgeBaseQaRecord extends BaseEntity { @TableField("question") private String question; + @Schema(title = "最终提供给LLM的提示词") + @TableField("prompt") + private String prompt; + + @Schema(title = "提供给LLM的提示词所消耗的token数量") + @TableField("prompt_tokens") + private Integer promptTokens; + @Schema(title = "答案") @TableField("answer") private String answer; + @Schema(title = "答案消耗的token") + @TableField("answer_tokens") + private Integer answerTokens; + @Schema(title = "提问用户id") @TableField("user_id") private Long userId; + } diff --git a/adi-common/src/main/java/com/moyz/adi/common/helper/RateLimitHelper.java b/adi-common/src/main/java/com/moyz/adi/common/helper/RateLimitHelper.java index a6b527e..bb6967a 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/helper/RateLimitHelper.java +++ b/adi-common/src/main/java/com/moyz/adi/common/helper/RateLimitHelper.java @@ -14,6 +14,13 @@ public class RateLimitHelper { @Resource private StringRedisTemplate stringRedisTemplate; + /** + * 按固定时间窗口计算请求次数 + * + * @param requestTimesKey redis key + * @param rateLimitConfig 请求频率限制配置 + * @return + */ public boolean checkRequestTimes(String requestTimesKey, RequestRateLimit rateLimitConfig) { int requestCountInTimeWindow = 0; String rateLimitVal = stringRedisTemplate.opsForValue().get(requestTimesKey); diff --git a/adi-common/src/main/java/com/moyz/adi/common/helper/SSEEmitterHelper.java b/adi-common/src/main/java/com/moyz/adi/common/helper/SSEEmitterHelper.java new file mode 100644 index 0000000..409df59 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/helper/SSEEmitterHelper.java @@ -0,0 +1,95 @@ +package com.moyz.adi.common.helper; + +import com.moyz.adi.common.cosntant.RedisKeyConstant; +import com.moyz.adi.common.entity.User; +import com.moyz.adi.common.interfaces.TriConsumer; +import com.moyz.adi.common.util.LocalCache; +import com.moyz.adi.common.vo.AnswerMeta; +import com.moyz.adi.common.vo.PromptMeta; +import com.moyz.adi.common.vo.SseAskParams; +import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.data.redis.core.StringRedisTemplate; +import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.io.IOException; +import java.text.MessageFormat; +import java.util.concurrent.TimeUnit; + +@Slf4j +@Service +public class SSEEmitterHelper { + + @Resource + private StringRedisTemplate stringRedisTemplate; + + @Resource + private RateLimitHelper rateLimitHelper; + + public void process(User user, SseAskParams sseAskParams, TriConsumer consumer) { + SseEmitter sseEmitter = sseAskParams.getSseEmitter(); + + //rate limit by system + String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId()); + if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) { + sendErrorMsg(sseEmitter, "访问太过频繁"); + return; + } + + //Check: If still waiting response + String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId()); + String askingVal = stringRedisTemplate.opsForValue().get(askingKey); + if (StringUtils.isNotBlank(askingVal)) { + sendErrorMsg(sseEmitter, "正在回复中..."); + return; + } + stringRedisTemplate.opsForValue().set(askingKey, "1", 15, TimeUnit.SECONDS); + try { + sseEmitter.send(SseEmitter.event().name("start")); + } catch (IOException e) { + log.error("error", e); + sseEmitter.completeWithError(e); + stringRedisTemplate.delete(askingKey); + return; + } + + rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG); + + sseEmitter.onCompletion(() -> { + log.info("response complete,uid:{}", user.getId()); + }); + sseEmitter.onTimeout(() -> log.warn("sseEmitter timeout,uid:{},on timeout:{}", user.getId(), sseEmitter.getTimeout())); + sseEmitter.onError( + throwable -> { + try { + log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable); + sseEmitter.send(SseEmitter.event().name("error").data(throwable.getMessage())); + } catch (IOException e) { + log.error("error", e); + } finally { + stringRedisTemplate.delete(askingKey); + } + } + ); + new LLMContext(sseAskParams.getModelName()).getLLMService().sseChat(sseAskParams, (response, promptMeta, answerMeta) -> { + try { + consumer.accept((String) response, (PromptMeta) promptMeta, (AnswerMeta) answerMeta); + } catch (Exception e) { + log.error("error:", e); + } finally { + stringRedisTemplate.delete(askingKey); + } + }); + } + + public void sendErrorMsg(SseEmitter sseEmitter, String errorMsg) { + try { + sseEmitter.send(SseEmitter.event().name("error").data(errorMsg)); + } catch (IOException e) { + throw new RuntimeException(e); + } + sseEmitter.complete(); + } +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java index 0173d5d..5f9b85c 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java @@ -1,23 +1,27 @@ package com.moyz.adi.common.interfaces; +import com.fasterxml.jackson.databind.util.JSONPObject; import com.moyz.adi.common.util.JsonUtil; import com.moyz.adi.common.util.LocalCache; import com.moyz.adi.common.vo.AnswerMeta; import com.moyz.adi.common.vo.ChatMeta; -import com.moyz.adi.common.vo.QuestionMeta; +import com.moyz.adi.common.vo.PromptMeta; import com.moyz.adi.common.vo.SseAskParams; +import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.Response; import dev.langchain4j.service.AiServices; import dev.langchain4j.service.TokenStream; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Value; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.io.IOException; import java.net.Proxy; +import java.util.HashMap; +import java.util.Map; import java.util.UUID; @Slf4j @@ -32,7 +36,7 @@ public abstract class AbstractLLMService { protected StreamingChatLanguageModel streamingChatLanguageModel; protected ChatLanguageModel chatLanguageModel; - public AbstractLLMService(String modelName, String settingName, Class clazz, Proxy proxy){ + public AbstractLLMService(String modelName, String settingName, Class clazz, Proxy proxy) { this.modelName = modelName; this.proxy = proxy; String st = LocalCache.CONFIGS.get(settingName); @@ -66,11 +70,13 @@ public abstract class AbstractLLMService { protected abstract StreamingChatLanguageModel buildStreamingChatLLM(); - public String chat(ChatMessage chatMessage) { - return getChatLLM().generate(chatMessage).content().text(); + protected abstract String parseError(Object error); + + public Response chat(ChatMessage chatMessage) { + return getChatLLM().generate(chatMessage); } - public void sseChat(SseAskParams params, TriConsumer consumer) { + public void sseChat(SseAskParams params, TriConsumer consumer) { //create chat assistant AiServices serviceBuilder = AiServices.builder(IChatAssistant.class) @@ -98,7 +104,7 @@ public abstract class AbstractLLMService { .onComplete((response) -> { log.info("返回数据结束了:{}", response); String questionUuid = StringUtils.isNotBlank(params.getRegenerateQuestionUuid()) ? params.getRegenerateQuestionUuid() : UUID.randomUUID().toString().replace("-", ""); - QuestionMeta questionMeta = new QuestionMeta(response.tokenUsage().inputTokenCount(), questionUuid); + PromptMeta questionMeta = new PromptMeta(response.tokenUsage().inputTokenCount(), questionUuid); AnswerMeta answerMeta = new AnswerMeta(response.tokenUsage().outputTokenCount(), UUID.randomUUID().toString().replace("-", "")); ChatMeta chatMeta = new ChatMeta(questionMeta, answerMeta); String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", ""); @@ -116,7 +122,11 @@ public abstract class AbstractLLMService { .onError((error) -> { log.error("stream error", error); try { - params.getSseEmitter().send(SseEmitter.event().name("error").data(error.getMessage())); + String errorMsg = parseError(error); + if(StringUtils.isBlank(errorMsg)){ + errorMsg = error.getMessage(); + } + params.getSseEmitter().send(SseEmitter.event().name("error").data(errorMsg)); } catch (IOException e) { log.error("sse error", e); } diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java b/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java index 00c1705..c2eb480 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java @@ -4,7 +4,6 @@ import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.moyz.adi.common.base.ThreadContext; import com.moyz.adi.common.cosntant.AdiConstant; -import com.moyz.adi.common.cosntant.RedisKeyConstant; import com.moyz.adi.common.dto.AskReq; import com.moyz.adi.common.entity.Conversation; import com.moyz.adi.common.entity.ConversationMessage; @@ -13,15 +12,14 @@ import com.moyz.adi.common.entity.UserDayCost; import com.moyz.adi.common.enums.ChatMessageRoleEnum; import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.exception.BaseException; -import com.moyz.adi.common.helper.LLMContext; import com.moyz.adi.common.helper.QuotaHelper; -import com.moyz.adi.common.helper.RateLimitHelper; +import com.moyz.adi.common.helper.SSEEmitterHelper; import com.moyz.adi.common.mapper.ConversationMessageMapper; import com.moyz.adi.common.util.LocalCache; import com.moyz.adi.common.util.LocalDateTimeUtil; import com.moyz.adi.common.util.UserUtil; import com.moyz.adi.common.vo.AnswerMeta; -import com.moyz.adi.common.vo.QuestionMeta; +import com.moyz.adi.common.vo.PromptMeta; import com.moyz.adi.common.vo.SseAskParams; import com.theokanning.openai.completion.chat.ChatMessageRole; import dev.langchain4j.data.message.SystemMessage; @@ -33,17 +31,13 @@ import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.context.annotation.Lazy; -import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; -import java.io.IOException; -import java.text.MessageFormat; import java.util.Comparator; import java.util.List; -import java.util.concurrent.TimeUnit; import static com.moyz.adi.common.enums.ErrorEnum.B_MESSAGE_NOT_FOUND; import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; @@ -56,9 +50,6 @@ public class ConversationMessageService extends ServiceImpl= convsMax) { - sendErrorMsg(sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax); + sseEmitterHelper.sendErrorMsg(sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax); return false; } - //check 4: current user's quota + //check 3: current user's quota ErrorEnum errorMsg = quotaHelper.checkTextQuota(user); if (null != errorMsg) { - sendErrorMsg(sseEmitter, errorMsg.getInfo()); + sseEmitterHelper.sendErrorMsg(sseEmitter, errorMsg.getInfo()); return false; } } catch (Exception e) { @@ -125,60 +108,16 @@ public class ConversationMessageService extends ServiceImpl { - log.info("response complete,uid:{}", user.getId()); - }); - sseEmitter.onTimeout(() -> log.warn("sseEmitter timeout,uid:{},on timeout:{}", user.getId(), sseEmitter.getTimeout())); - sseEmitter.onError( - throwable -> { - try { - log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable); - sseEmitter.send(SseEmitter.event().name("error").data(throwable.getMessage())); - } catch (IOException e) { - log.error("error", e); - } finally { - stringRedisTemplate.delete(askingKey); - } - } - ); SseAskParams sseAskParams = new SseAskParams(); + sseAskParams.setModelName(askReq.getModelName()); String prompt = askReq.getPrompt(); if (StringUtils.isNotBlank(askReq.getRegenerateQuestionUuid())) { prompt = getPromptMsgByQuestionUuid(askReq.getRegenerateQuestionUuid()).getRemark(); @@ -222,14 +161,8 @@ public class ConversationMessageService extends ServiceImpl { - try { - _this.saveAfterAiResponse(user, askReq, (String) response, (QuestionMeta) questionMeta, (AnswerMeta) answerMeta); - } catch (Exception e) { - log.error("error:", e); - } finally { - stringRedisTemplate.delete(askingKey); - } + sseEmitterHelper.process(user, sseAskParams, (response, questionMeta, answerMeta) -> { + _this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta); }); } @@ -245,7 +178,7 @@ public class ConversationMessageService extends ServiceImpl { .build(); } + @Override + protected String parseError(Object error) { + return null; + } + @Override protected ChatLanguageModel buildChatLLM() { if (StringUtils.isBlank(setting.getApiKey())) { diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseQaRecordService.java b/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseQaRecordService.java index f941c82..8fcc8d6 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseQaRecordService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseQaRecordService.java @@ -5,13 +5,17 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers; import com.moyz.adi.common.base.ThreadContext; +import com.moyz.adi.common.entity.KnowledgeBase; import com.moyz.adi.common.entity.KnowledgeBaseQaRecord; +import com.moyz.adi.common.entity.User; import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.mapper.KnowledgeBaseQaRecordMapper; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.stereotype.Service; +import java.util.UUID; + import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND; @Slf4j @@ -21,6 +25,7 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl search(String kbUuid, String keyword, Integer currentPage, Integer pageSize) { LambdaQueryWrapper wrapper = new LambdaQueryWrapper<>(); wrapper.eq(KnowledgeBaseQaRecord::getKbUuid, kbUuid); + wrapper.eq(KnowledgeBaseQaRecord::getIsDeleted, false); if (!ThreadContext.getCurrentUser().getIsAdmin()) { wrapper.eq(KnowledgeBaseQaRecord::getUserId, ThreadContext.getCurrentUserId()); } @@ -31,6 +36,36 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl(currentPage, pageSize), wrapper); } + /** + * 创建新的QA记录 + * + * @param knowledgeBase 所属的知识库 + * @param question 用户的原始问题 + * @param prompt 根据{question}生成的最终提示词 + * @param promptTokens 提示词消耗的token + * @param answer 答案 + * @param answerTokens 答案消耗的token + * @return + */ + public KnowledgeBaseQaRecord createNewRecord(User user, KnowledgeBase knowledgeBase, String question, String prompt, int promptTokens, String answer, int answerTokens) { + String uuid = UUID.randomUUID().toString().replace("-", ""); + KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord(); + newObj.setKbId(knowledgeBase.getId()); + newObj.setKbUuid((knowledgeBase.getUuid())); + newObj.setUuid(uuid); + newObj.setUserId(user.getId()); + newObj.setQuestion(question); + newObj.setPrompt(prompt); + newObj.setPromptTokens(promptTokens); + newObj.setAnswer(answer); + newObj.setAnswerTokens(answerTokens); + baseMapper.insert(newObj); + + LambdaQueryWrapper wrapper = new LambdaQueryWrapper<>(); + wrapper.eq(KnowledgeBaseQaRecord::getUuid, uuid); + return baseMapper.selectOne(wrapper); + } + public boolean softDelele(String uuid) { if (ThreadContext.getCurrentUser().getIsAdmin()) { return ChainWrappers.lambdaUpdateChain(baseMapper) diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseService.java b/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseService.java index 5cb1f11..b6ba019 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseService.java @@ -7,29 +7,35 @@ import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers; import com.moyz.adi.common.base.ThreadContext; import com.moyz.adi.common.cosntant.RedisKeyConstant; import com.moyz.adi.common.dto.KbEditReq; +import com.moyz.adi.common.dto.QAReq; import com.moyz.adi.common.entity.*; import com.moyz.adi.common.exception.BaseException; +import com.moyz.adi.common.helper.SSEEmitterHelper; import com.moyz.adi.common.mapper.KnowledgeBaseMapper; import com.moyz.adi.common.util.BizPager; import com.moyz.adi.common.util.LocalDateTimeUtil; +import com.moyz.adi.common.vo.SseAskParams; import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.parser.TextDocumentParser; import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentParser; import dev.langchain4j.data.document.parser.apache.poi.ApachePoiDocumentParser; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.output.Response; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.springframework.context.annotation.Lazy; import org.springframework.data.redis.core.StringRedisTemplate; +import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.text.MessageFormat; import java.time.LocalDateTime; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.UUID; +import java.util.*; import static com.moyz.adi.common.cosntant.AdiConstant.POI_DOC_TYPES; import static com.moyz.adi.common.enums.ErrorEnum.*; @@ -39,6 +45,10 @@ import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.load @Service public class KnowledgeBaseService extends ServiceImpl { + @Lazy + @Resource + private KnowledgeBaseService _this; + @Resource private StringRedisTemplate stringRedisTemplate; @@ -54,6 +64,12 @@ public class KnowledgeBaseService extends ServiceImpl> responsePair = ragService.retrieveAndAsk(kbUuid, question, modelName); + Response ar = responsePair.getRight(); + int inputTokenCount = ar.tokenUsage().inputTokenCount(); + int outputTokenCount = ar.tokenUsage().outputTokenCount(); + userDayCostService.appendCostToUser(ThreadContext.getCurrentUser(), inputTokenCount + outputTokenCount); + return knowledgeBaseQaRecordService.createNewRecord(ThreadContext.getCurrentUser(), knowledgeBase, question, responsePair.getLeft(), inputTokenCount, ar.content().text(), outputTokenCount); + } + + public SseEmitter sseAsk(String kbUuid, QAReq req) { + checkRequestTimesOrThrow(); + SseEmitter sseEmitter = new SseEmitter(); + _this.retrieveAndPushToLLM(ThreadContext.getCurrentUser(), sseEmitter, kbUuid, req); + return sseEmitter; + } + + /** + * 知识库问答限额判断 + */ + private void checkRequestTimesOrThrow() { String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd")); String askTimes = stringRedisTemplate.opsForValue().get(key); String askQuota = SysConfigService.getByKey("quota_by_qa_ask_daily"); @@ -202,19 +238,24 @@ public class KnowledgeBaseService extends ServiceImpl { + knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), prompt, promptMeta.getTokens(), response, answerMeta.getTokens()); + userDayCostService.appendCostToUser(user, promptMeta.getTokens() + answerMeta.getTokens()); + }); } public KnowledgeBase getOrThrow(String kbUuid) { @@ -248,4 +289,5 @@ public class KnowledgeBaseService extends ServiceImpl { return builder.build(); } + @Override + protected String parseError(Object error) { + if(error instanceof OpenAiHttpException){ + OpenAiHttpException openAiHttpException = (OpenAiHttpException)error; + OpenAiError openAiError = JsonUtil.fromJson(openAiHttpException.getMessage(), OpenAiError.class); + return openAiError.getError().getMessage(); + } + return Strings.EMPTY; + } + @Override protected ChatLanguageModel buildChatLLM() { if (StringUtils.isBlank(setting.getSecretKey())) { diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/QianFanLLMService.java b/adi-common/src/main/java/com/moyz/adi/common/service/QianFanLLMService.java index 0b45414..16f94b1 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/QianFanLLMService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/QianFanLLMService.java @@ -49,4 +49,9 @@ public class QianFanLLMService extends AbstractLLMService { .secretKey(setting.getSecretKey()) .build(); } + + @Override + protected String parseError(Object error) { + return null; + } } diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java b/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java index b739480..af31f51 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java @@ -1,22 +1,30 @@ package com.moyz.adi.common.service; import com.moyz.adi.common.helper.LLMContext; +import com.moyz.adi.common.interfaces.TriConsumer; import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore; +import com.moyz.adi.common.vo.AnswerMeta; +import com.moyz.adi.common.vo.PromptMeta; import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.DocumentSplitter; import dev.langchain4j.data.document.splitter.DocumentSplitters; import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.openai.OpenAiTokenizer; +import dev.langchain4j.model.output.Response; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreIngestor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.Triple; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; @@ -104,16 +112,7 @@ public class RAGService { getEmbeddingStoreIngestor().ingest(document); } - /** - * 召回并搜索 - * - * @param kbUuid 知识库uuid - * @param question 用户的问题 - * @param modelName LLM model name - * @return - */ - public String findAnswer(String kbUuid, String question, String modelName) { - + public Prompt retrieveAndCreatePrompt(String kbUuid, String question) { // Embed the question Embedding questionEmbedding = embeddingModel.embed(question).content(); @@ -129,10 +128,25 @@ public class RAGService { .collect(joining("\n\n")); if (StringUtils.isBlank(information)) { - return StringUtils.EMPTY; + return null; } - Prompt prompt = promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information))); + return promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information))); + } - return new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage()); + /** + * 召回并提问 + * + * @param kbUuid 知识库uuid + * @param question 用户的问题 + * @param modelName LLM model name + * @return + */ + public Pair> retrieveAndAsk(String kbUuid, String question, String modelName) { + Prompt prompt = retrieveAndCreatePrompt(kbUuid, question); + if (null == prompt) { + return null; + } + Response response = new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage()); + return new ImmutablePair<>(prompt.text(), response); } } diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/UserDayCostService.java b/adi-common/src/main/java/com/moyz/adi/common/service/UserDayCostService.java index b66b081..b4d470b 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/UserDayCostService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/UserDayCostService.java @@ -18,6 +18,23 @@ import java.util.List; @Service public class UserDayCostService extends ServiceImpl { + public void appendCostToUser(User user, int tokens) { + UserDayCost userDayCost = getTodayCost(user); + UserDayCost saveOrUpdateInst = new UserDayCost(); + if (null == userDayCost) { + saveOrUpdateInst.setUserId(user.getId()); + saveOrUpdateInst.setDay(LocalDateTimeUtil.getToday()); + saveOrUpdateInst.setTokens(tokens); + saveOrUpdateInst.setRequests(1); + saveOrUpdateInst.setSecretKeyType(UserUtil.getSecretType(user)); + } else { + saveOrUpdateInst.setId(userDayCost.getId()); + saveOrUpdateInst.setTokens(userDayCost.getTokens() + tokens); + saveOrUpdateInst.setRequests(userDayCost.getRequests() + 1); + } + saveOrUpdate(saveOrUpdateInst); + } + public CostStat costStatByUser(long userId) { CostStat result = new CostStat(); diff --git a/adi-common/src/main/java/com/moyz/adi/common/vo/ChatMeta.java b/adi-common/src/main/java/com/moyz/adi/common/vo/ChatMeta.java index cb2118c..cae05d9 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/vo/ChatMeta.java +++ b/adi-common/src/main/java/com/moyz/adi/common/vo/ChatMeta.java @@ -6,6 +6,6 @@ import lombok.Data; @Data @AllArgsConstructor public class ChatMeta { - private QuestionMeta question; + private PromptMeta question; private AnswerMeta answer; } diff --git a/adi-common/src/main/java/com/moyz/adi/common/vo/QuestionMeta.java b/adi-common/src/main/java/com/moyz/adi/common/vo/PromptMeta.java similarity index 85% rename from adi-common/src/main/java/com/moyz/adi/common/vo/QuestionMeta.java rename to adi-common/src/main/java/com/moyz/adi/common/vo/PromptMeta.java index 6d17070..b2094b4 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/vo/QuestionMeta.java +++ b/adi-common/src/main/java/com/moyz/adi/common/vo/PromptMeta.java @@ -5,7 +5,7 @@ import lombok.Data; @Data @AllArgsConstructor -public class QuestionMeta { +public class PromptMeta { private Integer tokens; private String uuid; } diff --git a/adi-common/src/main/java/com/moyz/adi/common/vo/SseAskParams.java b/adi-common/src/main/java/com/moyz/adi/common/vo/SseAskParams.java index 9c79160..a124c08 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/vo/SseAskParams.java +++ b/adi-common/src/main/java/com/moyz/adi/common/vo/SseAskParams.java @@ -21,4 +21,5 @@ public class SseAskParams { private SseEmitter sseEmitter; + private String modelName; } diff --git a/docs/create.sql b/docs/create.sql index c428066..c9bc170 100644 --- a/docs/create.sql +++ b/docs/create.sql @@ -200,11 +200,11 @@ COMMENT ON COLUMN public.adi_prompt.is_deleted IS '0:未删除;1:已删除'; CREATE TABLE public.adi_sys_config ( id bigserial primary key, - name character varying(100) DEFAULT ''::character varying NOT NULL, + name character varying(100) DEFAULT ''::character varying NOT NULL, value character varying(1000) DEFAULT ''::character varying NOT NULL, - create_time timestamp DEFAULT localtimestamp NOT NULL, - update_time timestamp DEFAULT localtimestamp NOT NULL, - is_deleted boolean DEFAULT false NOT NULL + create_time timestamp DEFAULT localtimestamp NOT NULL, + update_time timestamp DEFAULT localtimestamp NOT NULL, + is_deleted boolean DEFAULT false NOT NULL ); COMMENT ON TABLE public.adi_sys_config IS '系统配置表'; @@ -497,7 +497,10 @@ create table adi_knowledge_base_qa_record kb_id bigint DEFAULT '0'::bigint NOT NULL, kb_uuid varchar(32) default ''::character varying not null, question varchar(1000) default ''::character varying not null, + prompt text default ''::character varying not null, + prompt_tokens integer DEFAULT 0 NOT NULL, answer text default ''::character varying not null, + answer_tokens integer DEFAULT 0 NOT NULL, source_file_ids varchar(500) default ''::character varying not null, user_id bigint default '0' NOT NULL, create_time timestamp default CURRENT_TIMESTAMP not null, @@ -511,10 +514,16 @@ comment on column adi_knowledge_base_qa_record.kb_id is '所属知识库id'; comment on column adi_knowledge_base_qa_record.kb_uuid is '所属知识库uuid'; -comment on column adi_knowledge_base_qa_record.question is '问题'; +comment on column adi_knowledge_base_qa_record.question is '用户的原始问题'; + +comment on column adi_knowledge_base_qa_record.prompt is '提供给LLM的提示词'; + +comment on column adi_knowledge_base_qa_record.prompt_tokens is '提示词消耗的token'; comment on column adi_knowledge_base_qa_record.answer is '答案'; +comment on column adi_knowledge_base_qa_record.answer_tokens is '答案消耗的token'; + comment on column adi_knowledge_base_qa_record.source_file_ids is '来源文档id,以逗号隔开'; comment on column adi_knowledge_base_qa_record.user_id is '提问用户id';