fix:知识库token计算问题

This commit is contained in:
moyangzhan 2024-03-12 00:41:49 +08:00
parent b1a9cba31e
commit 498e9131b9
18 changed files with 347 additions and 151 deletions

View File

@ -5,11 +5,11 @@
> **该项目如对您有帮助,欢迎点赞** > **该项目如对您有帮助,欢迎点赞**
### 体验网址 ## 体验网址
[http://www.aideepin.com](http://www.aideepin.com/) [http://www.aideepin.com](http://www.aideepin.com/)
### 功能点 ## 功能点
* 注册&登录 * 注册&登录
* 多会话(多角色) * 多会话(多角色)
@ -19,14 +19,14 @@
* 基于大模型的知识库RAG * 基于大模型的知识库RAG
* 多模型随意切换 * 多模型随意切换
### 接入的模型: ## 接入的模型:
* ChatGPT 3.5 * ChatGPT 3.5
* 通义千问 * 通义千问
* 文心一言 * 文心一言
* DALL-E 2 * DALL-E 2
### 技术栈 ## 技术栈
该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web) 该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web)
@ -44,9 +44,9 @@ springboot3.0.5
vue3+typescript+pnpm vue3+typescript+pnpm
### 如何部署 ## 如何部署
#### 初始化 ### 初始化
初始化数据库 初始化数据库
@ -65,7 +65,7 @@ update adi_sys_config set value = '{"api_key":"my_dashcope_api_key"}' where name
* redis: application-[dev|prod].xml中的spring.data.redis * redis: application-[dev|prod].xml中的spring.data.redis
* mail: application.xml中的spring.mail * mail: application.xml中的spring.mail
#### 编译及运行 ### 编译及运行
* 进入项目 * 进入项目
@ -100,12 +100,12 @@ docker run -d \
aideepin:0.0.1 aideepin:0.0.1
``` ```
### 待办: ## 待办:
增强RAG 增强RAG
### 截图 ## 截图
**AI聊天** **AI聊天**
![1691583184761](image/README/1691583184761.png) ![1691583184761](image/README/1691583184761.png)

View File

@ -5,12 +5,15 @@ import com.moyz.adi.common.dto.QAReq;
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord; import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
import com.moyz.adi.common.service.KnowledgeBaseQaRecordService; import com.moyz.adi.common.service.KnowledgeBaseQaRecordService;
import com.moyz.adi.common.service.KnowledgeBaseService; import com.moyz.adi.common.service.KnowledgeBaseService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import jakarta.validation.constraints.Min; import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import org.springframework.http.MediaType;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@Tag(name = "知识库问答controller") @Tag(name = "知识库问答controller")
@RequestMapping("/knowledge-base/qa/") @RequestMapping("/knowledge-base/qa/")
@ -25,7 +28,13 @@ public class KnowledgeBaseQAController {
@PostMapping("/ask/{kbUuid}") @PostMapping("/ask/{kbUuid}")
public KnowledgeBaseQaRecord ask(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) { public KnowledgeBaseQaRecord ask(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) {
return knowledgeBaseService.answerAndRecord(kbUuid, req.getQuestion(), req.getModelName()); return knowledgeBaseService.ask(kbUuid, req.getQuestion(), req.getModelName());
}
@Operation(summary = "流式响应")
@PostMapping(value = "/process/{kbUuid}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter sseAsk(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) {
return knowledgeBaseService.sseAsk(kbUuid, req);
} }
@GetMapping("/record/search") @GetMapping("/record/search")

View File

@ -31,11 +31,24 @@ public class KnowledgeBaseQaRecord extends BaseEntity {
@TableField("question") @TableField("question")
private String question; private String question;
@Schema(title = "最终提供给LLM的提示词")
@TableField("prompt")
private String prompt;
@Schema(title = "提供给LLM的提示词所消耗的token数量")
@TableField("prompt_tokens")
private Integer promptTokens;
@Schema(title = "答案") @Schema(title = "答案")
@TableField("answer") @TableField("answer")
private String answer; private String answer;
@Schema(title = "答案消耗的token")
@TableField("answer_tokens")
private Integer answerTokens;
@Schema(title = "提问用户id") @Schema(title = "提问用户id")
@TableField("user_id") @TableField("user_id")
private Long userId; private Long userId;
} }

View File

@ -14,6 +14,13 @@ public class RateLimitHelper {
@Resource @Resource
private StringRedisTemplate stringRedisTemplate; private StringRedisTemplate stringRedisTemplate;
/**
* 按固定时间窗口计算请求次数
*
* @param requestTimesKey redis key
* @param rateLimitConfig 请求频率限制配置
* @return
*/
public boolean checkRequestTimes(String requestTimesKey, RequestRateLimit rateLimitConfig) { public boolean checkRequestTimes(String requestTimesKey, RequestRateLimit rateLimitConfig) {
int requestCountInTimeWindow = 0; int requestCountInTimeWindow = 0;
String rateLimitVal = stringRedisTemplate.opsForValue().get(requestTimesKey); String rateLimitVal = stringRedisTemplate.opsForValue().get(requestTimesKey);

View File

@ -0,0 +1,95 @@
package com.moyz.adi.common.helper;
import com.moyz.adi.common.cosntant.RedisKeyConstant;
import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.interfaces.TriConsumer;
import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.vo.AnswerMeta;
import com.moyz.adi.common.vo.PromptMeta;
import com.moyz.adi.common.vo.SseAskParams;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.concurrent.TimeUnit;
@Slf4j
@Service
public class SSEEmitterHelper {
@Resource
private StringRedisTemplate stringRedisTemplate;
@Resource
private RateLimitHelper rateLimitHelper;
public void process(User user, SseAskParams sseAskParams, TriConsumer<String, PromptMeta, AnswerMeta> consumer) {
SseEmitter sseEmitter = sseAskParams.getSseEmitter();
//rate limit by system
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) {
sendErrorMsg(sseEmitter, "访问太过频繁");
return;
}
//Check: If still waiting response
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
if (StringUtils.isNotBlank(askingVal)) {
sendErrorMsg(sseEmitter, "正在回复中...");
return;
}
stringRedisTemplate.opsForValue().set(askingKey, "1", 15, TimeUnit.SECONDS);
try {
sseEmitter.send(SseEmitter.event().name("start"));
} catch (IOException e) {
log.error("error", e);
sseEmitter.completeWithError(e);
stringRedisTemplate.delete(askingKey);
return;
}
rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG);
sseEmitter.onCompletion(() -> {
log.info("response complete,uid:{}", user.getId());
});
sseEmitter.onTimeout(() -> log.warn("sseEmitter timeout,uid:{},on timeout:{}", user.getId(), sseEmitter.getTimeout()));
sseEmitter.onError(
throwable -> {
try {
log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable);
sseEmitter.send(SseEmitter.event().name("error").data(throwable.getMessage()));
} catch (IOException e) {
log.error("error", e);
} finally {
stringRedisTemplate.delete(askingKey);
}
}
);
new LLMContext(sseAskParams.getModelName()).getLLMService().sseChat(sseAskParams, (response, promptMeta, answerMeta) -> {
try {
consumer.accept((String) response, (PromptMeta) promptMeta, (AnswerMeta) answerMeta);
} catch (Exception e) {
log.error("error:", e);
} finally {
stringRedisTemplate.delete(askingKey);
}
});
}
public void sendErrorMsg(SseEmitter sseEmitter, String errorMsg) {
try {
sseEmitter.send(SseEmitter.event().name("error").data(errorMsg));
} catch (IOException e) {
throw new RuntimeException(e);
}
sseEmitter.complete();
}
}

View File

@ -1,23 +1,27 @@
package com.moyz.adi.common.interfaces; package com.moyz.adi.common.interfaces;
import com.fasterxml.jackson.databind.util.JSONPObject;
import com.moyz.adi.common.util.JsonUtil; import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.LocalCache; import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.vo.AnswerMeta; import com.moyz.adi.common.vo.AnswerMeta;
import com.moyz.adi.common.vo.ChatMeta; import com.moyz.adi.common.vo.ChatMeta;
import com.moyz.adi.common.vo.QuestionMeta; import com.moyz.adi.common.vo.PromptMeta;
import com.moyz.adi.common.vo.SseAskParams; import com.moyz.adi.common.vo.SseAskParams;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.AiServices; import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream; import dev.langchain4j.service.TokenStream;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException; import java.io.IOException;
import java.net.Proxy; import java.net.Proxy;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID; import java.util.UUID;
@Slf4j @Slf4j
@ -32,7 +36,7 @@ public abstract class AbstractLLMService<T> {
protected StreamingChatLanguageModel streamingChatLanguageModel; protected StreamingChatLanguageModel streamingChatLanguageModel;
protected ChatLanguageModel chatLanguageModel; protected ChatLanguageModel chatLanguageModel;
public AbstractLLMService(String modelName, String settingName, Class<T> clazz, Proxy proxy){ public AbstractLLMService(String modelName, String settingName, Class<T> clazz, Proxy proxy) {
this.modelName = modelName; this.modelName = modelName;
this.proxy = proxy; this.proxy = proxy;
String st = LocalCache.CONFIGS.get(settingName); String st = LocalCache.CONFIGS.get(settingName);
@ -66,11 +70,13 @@ public abstract class AbstractLLMService<T> {
protected abstract StreamingChatLanguageModel buildStreamingChatLLM(); protected abstract StreamingChatLanguageModel buildStreamingChatLLM();
public String chat(ChatMessage chatMessage) { protected abstract String parseError(Object error);
return getChatLLM().generate(chatMessage).content().text();
public Response<AiMessage> chat(ChatMessage chatMessage) {
return getChatLLM().generate(chatMessage);
} }
public void sseChat(SseAskParams params, TriConsumer<String, QuestionMeta, AnswerMeta> consumer) { public void sseChat(SseAskParams params, TriConsumer<String, PromptMeta, AnswerMeta> consumer) {
//create chat assistant //create chat assistant
AiServices<IChatAssistant> serviceBuilder = AiServices.builder(IChatAssistant.class) AiServices<IChatAssistant> serviceBuilder = AiServices.builder(IChatAssistant.class)
@ -98,7 +104,7 @@ public abstract class AbstractLLMService<T> {
.onComplete((response) -> { .onComplete((response) -> {
log.info("返回数据结束了:{}", response); log.info("返回数据结束了:{}", response);
String questionUuid = StringUtils.isNotBlank(params.getRegenerateQuestionUuid()) ? params.getRegenerateQuestionUuid() : UUID.randomUUID().toString().replace("-", ""); String questionUuid = StringUtils.isNotBlank(params.getRegenerateQuestionUuid()) ? params.getRegenerateQuestionUuid() : UUID.randomUUID().toString().replace("-", "");
QuestionMeta questionMeta = new QuestionMeta(response.tokenUsage().inputTokenCount(), questionUuid); PromptMeta questionMeta = new PromptMeta(response.tokenUsage().inputTokenCount(), questionUuid);
AnswerMeta answerMeta = new AnswerMeta(response.tokenUsage().outputTokenCount(), UUID.randomUUID().toString().replace("-", "")); AnswerMeta answerMeta = new AnswerMeta(response.tokenUsage().outputTokenCount(), UUID.randomUUID().toString().replace("-", ""));
ChatMeta chatMeta = new ChatMeta(questionMeta, answerMeta); ChatMeta chatMeta = new ChatMeta(questionMeta, answerMeta);
String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", ""); String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", "");
@ -116,7 +122,11 @@ public abstract class AbstractLLMService<T> {
.onError((error) -> { .onError((error) -> {
log.error("stream error", error); log.error("stream error", error);
try { try {
params.getSseEmitter().send(SseEmitter.event().name("error").data(error.getMessage())); String errorMsg = parseError(error);
if(StringUtils.isBlank(errorMsg)){
errorMsg = error.getMessage();
}
params.getSseEmitter().send(SseEmitter.event().name("error").data(errorMsg));
} catch (IOException e) { } catch (IOException e) {
log.error("sse error", e); log.error("sse error", e);
} }

View File

@ -4,7 +4,6 @@ import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.moyz.adi.common.base.ThreadContext; import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.cosntant.AdiConstant; import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.cosntant.RedisKeyConstant;
import com.moyz.adi.common.dto.AskReq; import com.moyz.adi.common.dto.AskReq;
import com.moyz.adi.common.entity.Conversation; import com.moyz.adi.common.entity.Conversation;
import com.moyz.adi.common.entity.ConversationMessage; import com.moyz.adi.common.entity.ConversationMessage;
@ -13,15 +12,14 @@ import com.moyz.adi.common.entity.UserDayCost;
import com.moyz.adi.common.enums.ChatMessageRoleEnum; import com.moyz.adi.common.enums.ChatMessageRoleEnum;
import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.helper.LLMContext;
import com.moyz.adi.common.helper.QuotaHelper; import com.moyz.adi.common.helper.QuotaHelper;
import com.moyz.adi.common.helper.RateLimitHelper; import com.moyz.adi.common.helper.SSEEmitterHelper;
import com.moyz.adi.common.mapper.ConversationMessageMapper; import com.moyz.adi.common.mapper.ConversationMessageMapper;
import com.moyz.adi.common.util.LocalCache; import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.util.LocalDateTimeUtil; import com.moyz.adi.common.util.LocalDateTimeUtil;
import com.moyz.adi.common.util.UserUtil; import com.moyz.adi.common.util.UserUtil;
import com.moyz.adi.common.vo.AnswerMeta; import com.moyz.adi.common.vo.AnswerMeta;
import com.moyz.adi.common.vo.QuestionMeta; import com.moyz.adi.common.vo.PromptMeta;
import com.moyz.adi.common.vo.SseAskParams; import com.moyz.adi.common.vo.SseAskParams;
import com.theokanning.openai.completion.chat.ChatMessageRole; import com.theokanning.openai.completion.chat.ChatMessageRole;
import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.SystemMessage;
@ -33,17 +31,13 @@ import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.context.annotation.Lazy; import org.springframework.context.annotation.Lazy;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit;
import static com.moyz.adi.common.enums.ErrorEnum.B_MESSAGE_NOT_FOUND; import static com.moyz.adi.common.enums.ErrorEnum.B_MESSAGE_NOT_FOUND;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
@ -56,9 +50,6 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
@Resource @Resource
private ConversationMessageService _this; private ConversationMessageService _this;
@Resource
private StringRedisTemplate stringRedisTemplate;
@Resource @Resource
private QuotaHelper quotaHelper; private QuotaHelper quotaHelper;
@ -70,51 +61,43 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
private ConversationService conversationService; private ConversationService conversationService;
@Resource @Resource
private RateLimitHelper rateLimitHelper; private SSEEmitterHelper sseEmitterHelper;
public SseEmitter sseAsk(AskReq askReq) { public SseEmitter sseAsk(AskReq askReq) {
SseEmitter sseEmitter = new SseEmitter(); SseEmitter sseEmitter = new SseEmitter();
User user = ThreadContext.getCurrentUser(); _this.asyncCheckAndPushToClient(sseEmitter, ThreadContext.getCurrentUser(), askReq);
_this.asyncCheckAndPushToClient(sseEmitter, user, askReq);
return sseEmitter; return sseEmitter;
} }
private boolean check(SseEmitter sseEmitter, User user, AskReq askReq) { private boolean check(SseEmitter sseEmitter, User user, AskReq askReq) {
try { try {
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
//check 1: still waiting response
if (StringUtils.isNotBlank(askingVal)) {
sendErrorMsg(sseEmitter, "正在回复中...");
return false;
}
//check 2: the conversation has been deleted //check 1: the conversation has been deleted
Conversation delConv = conversationService.lambdaQuery() Conversation delConv = conversationService.lambdaQuery()
.eq(Conversation::getUuid, askReq.getConversationUuid()) .eq(Conversation::getUuid, askReq.getConversationUuid())
.eq(Conversation::getIsDeleted, true) .eq(Conversation::getIsDeleted, true)
.one(); .one();
if (null != delConv) { if (null != delConv) {
sendErrorMsg(sseEmitter, "该对话已经删除"); sseEmitterHelper.sendErrorMsg(sseEmitter, "该对话已经删除");
return false; return false;
} }
//check 3: conversation quota //check 2: conversation quota
Long convsCount = conversationService.lambdaQuery() Long convsCount = conversationService.lambdaQuery()
.eq(Conversation::getUserId, user.getId()) .eq(Conversation::getUserId, user.getId())
.eq(Conversation::getIsDeleted, false) .eq(Conversation::getIsDeleted, false)
.count(); .count();
long convsMax = Integer.parseInt(LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.CONVERSATION_MAX_NUM)); long convsMax = Integer.parseInt(LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.CONVERSATION_MAX_NUM));
if (convsCount >= convsMax) { if (convsCount >= convsMax) {
sendErrorMsg(sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax); sseEmitterHelper.sendErrorMsg(sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax);
return false; return false;
} }
//check 4: current user's quota //check 3: current user's quota
ErrorEnum errorMsg = quotaHelper.checkTextQuota(user); ErrorEnum errorMsg = quotaHelper.checkTextQuota(user);
if (null != errorMsg) { if (null != errorMsg) {
sendErrorMsg(sseEmitter, errorMsg.getInfo()); sseEmitterHelper.sendErrorMsg(sseEmitter, errorMsg.getInfo());
return false; return false;
} }
} catch (Exception e) { } catch (Exception e) {
@ -125,60 +108,16 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
return true; return true;
} }
private void sendErrorMsg(SseEmitter sseEmitter, String errorMsg) {
try {
sseEmitter.send(SseEmitter.event().name("error").data(errorMsg));
} catch (IOException e) {
throw new RuntimeException(e);
}
sseEmitter.complete();
}
@Async @Async
public void asyncCheckAndPushToClient(SseEmitter sseEmitter, User user, AskReq askReq) { public void asyncCheckAndPushToClient(SseEmitter sseEmitter, User user, AskReq askReq) {
log.info("asyncCheckAndPushToClient,userId:{}", user.getId()); log.info("asyncCheckAndPushToClient,userId:{}", user.getId());
//rate limit by system
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) {
sendErrorMsg(sseEmitter, "访问太过频繁");
return;
}
//check business rules //check business rules
if (!check(sseEmitter, user, askReq)) { if (!check(sseEmitter, user, askReq)) {
return; return;
} }
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
stringRedisTemplate.opsForValue().set(askingKey, "1", 15, TimeUnit.SECONDS);
try {
sseEmitter.send(SseEmitter.event().name("start"));
} catch (IOException e) {
log.error("error", e);
sseEmitter.completeWithError(e);
stringRedisTemplate.delete(askingKey);
return;
}
rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG);
sseEmitter.onCompletion(() -> {
log.info("response complete,uid:{}", user.getId());
});
sseEmitter.onTimeout(() -> log.warn("sseEmitter timeout,uid:{},on timeout:{}", user.getId(), sseEmitter.getTimeout()));
sseEmitter.onError(
throwable -> {
try {
log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable);
sseEmitter.send(SseEmitter.event().name("error").data(throwable.getMessage()));
} catch (IOException e) {
log.error("error", e);
} finally {
stringRedisTemplate.delete(askingKey);
}
}
);
SseAskParams sseAskParams = new SseAskParams(); SseAskParams sseAskParams = new SseAskParams();
sseAskParams.setModelName(askReq.getModelName());
String prompt = askReq.getPrompt(); String prompt = askReq.getPrompt();
if (StringUtils.isNotBlank(askReq.getRegenerateQuestionUuid())) { if (StringUtils.isNotBlank(askReq.getRegenerateQuestionUuid())) {
prompt = getPromptMsgByQuestionUuid(askReq.getRegenerateQuestionUuid()).getRemark(); prompt = getPromptMsgByQuestionUuid(askReq.getRegenerateQuestionUuid()).getRemark();
@ -222,14 +161,8 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
} }
} }
new LLMContext(askReq.getModelName()).getLLMService().sseChat(sseAskParams, (response, questionMeta, answerMeta) -> { sseEmitterHelper.process(user, sseAskParams, (response, questionMeta, answerMeta) -> {
try { _this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta);
_this.saveAfterAiResponse(user, askReq, (String) response, (QuestionMeta) questionMeta, (AnswerMeta) answerMeta);
} catch (Exception e) {
log.error("error:", e);
} finally {
stringRedisTemplate.delete(askingKey);
}
}); });
} }
@ -245,7 +178,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
} }
@Transactional @Transactional
public void saveAfterAiResponse(User user, AskReq askReq, String response, QuestionMeta questionMeta, AnswerMeta answerMeta) { public void saveAfterAiResponse(User user, AskReq askReq, String response, PromptMeta questionMeta, AnswerMeta answerMeta) {
int secretKeyType = StringUtils.isNotBlank(user.getSecretKey()) ? AdiConstant.SECRET_KEY_TYPE_CUSTOM : AdiConstant.SECRET_KEY_TYPE_SYSTEM; int secretKeyType = StringUtils.isNotBlank(user.getSecretKey()) ? AdiConstant.SECRET_KEY_TYPE_CUSTOM : AdiConstant.SECRET_KEY_TYPE_SYSTEM;
Conversation conversation; Conversation conversation;
@ -294,7 +227,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
} }
private void calcTodayCost(User user, Conversation conversation, QuestionMeta questionMeta, AnswerMeta answerMeta) { private void calcTodayCost(User user, Conversation conversation, PromptMeta questionMeta, AnswerMeta answerMeta) {
int todayTokenCost = questionMeta.getTokens() + answerMeta.getTokens(); int todayTokenCost = questionMeta.getTokens() + answerMeta.getTokens();
try { try {
@ -304,20 +237,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
.set(Conversation::getTokens, conversation.getTokens() + todayTokenCost) .set(Conversation::getTokens, conversation.getTokens() + todayTokenCost)
.update(); .update();
UserDayCost userDayCost = userDayCostService.getTodayCost(user); userDayCostService.appendCostToUser(user, todayTokenCost);
UserDayCost saveOrUpdateInst = new UserDayCost();
if (null == userDayCost) {
saveOrUpdateInst.setUserId(user.getId());
saveOrUpdateInst.setDay(LocalDateTimeUtil.getToday());
saveOrUpdateInst.setTokens(todayTokenCost);
saveOrUpdateInst.setRequests(1);
saveOrUpdateInst.setSecretKeyType(UserUtil.getSecretType(user));
} else {
saveOrUpdateInst.setId(userDayCost.getId());
saveOrUpdateInst.setTokens(userDayCost.getTokens() + todayTokenCost);
saveOrUpdateInst.setRequests(userDayCost.getRequests() + 1);
}
userDayCostService.saveOrUpdate(saveOrUpdateInst);
} catch (Exception e) { } catch (Exception e) {
log.error("calcTodayCost error", e); log.error("calcTodayCost error", e);
} }

View File

@ -41,6 +41,11 @@ public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> {
.build(); .build();
} }
@Override
protected String parseError(Object error) {
return null;
}
@Override @Override
protected ChatLanguageModel buildChatLLM() { protected ChatLanguageModel buildChatLLM() {
if (StringUtils.isBlank(setting.getApiKey())) { if (StringUtils.isBlank(setting.getApiKey())) {

View File

@ -5,13 +5,17 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers; import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.moyz.adi.common.base.ThreadContext; import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.entity.KnowledgeBase;
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord; import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.mapper.KnowledgeBaseQaRecordMapper; import com.moyz.adi.common.mapper.KnowledgeBaseQaRecordMapper;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.UUID;
import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND; import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND;
@Slf4j @Slf4j
@ -21,6 +25,7 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
public Page<KnowledgeBaseQaRecord> search(String kbUuid, String keyword, Integer currentPage, Integer pageSize) { public Page<KnowledgeBaseQaRecord> search(String kbUuid, String keyword, Integer currentPage, Integer pageSize) {
LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>(); LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>();
wrapper.eq(KnowledgeBaseQaRecord::getKbUuid, kbUuid); wrapper.eq(KnowledgeBaseQaRecord::getKbUuid, kbUuid);
wrapper.eq(KnowledgeBaseQaRecord::getIsDeleted, false);
if (!ThreadContext.getCurrentUser().getIsAdmin()) { if (!ThreadContext.getCurrentUser().getIsAdmin()) {
wrapper.eq(KnowledgeBaseQaRecord::getUserId, ThreadContext.getCurrentUserId()); wrapper.eq(KnowledgeBaseQaRecord::getUserId, ThreadContext.getCurrentUserId());
} }
@ -31,6 +36,36 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
return baseMapper.selectPage(new Page<>(currentPage, pageSize), wrapper); return baseMapper.selectPage(new Page<>(currentPage, pageSize), wrapper);
} }
/**
* 创建新的QA记录
*
* @param knowledgeBase 所属的知识库
* @param question 用户的原始问题
* @param prompt 根据{question}生成的最终提示词
* @param promptTokens 提示词消耗的token
* @param answer 答案
* @param answerTokens 答案消耗的token
* @return
*/
public KnowledgeBaseQaRecord createNewRecord(User user, KnowledgeBase knowledgeBase, String question, String prompt, int promptTokens, String answer, int answerTokens) {
String uuid = UUID.randomUUID().toString().replace("-", "");
KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord();
newObj.setKbId(knowledgeBase.getId());
newObj.setKbUuid((knowledgeBase.getUuid()));
newObj.setUuid(uuid);
newObj.setUserId(user.getId());
newObj.setQuestion(question);
newObj.setPrompt(prompt);
newObj.setPromptTokens(promptTokens);
newObj.setAnswer(answer);
newObj.setAnswerTokens(answerTokens);
baseMapper.insert(newObj);
LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>();
wrapper.eq(KnowledgeBaseQaRecord::getUuid, uuid);
return baseMapper.selectOne(wrapper);
}
public boolean softDelele(String uuid) { public boolean softDelele(String uuid) {
if (ThreadContext.getCurrentUser().getIsAdmin()) { if (ThreadContext.getCurrentUser().getIsAdmin()) {
return ChainWrappers.lambdaUpdateChain(baseMapper) return ChainWrappers.lambdaUpdateChain(baseMapper)

View File

@ -7,29 +7,35 @@ import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.moyz.adi.common.base.ThreadContext; import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.cosntant.RedisKeyConstant; import com.moyz.adi.common.cosntant.RedisKeyConstant;
import com.moyz.adi.common.dto.KbEditReq; import com.moyz.adi.common.dto.KbEditReq;
import com.moyz.adi.common.dto.QAReq;
import com.moyz.adi.common.entity.*; import com.moyz.adi.common.entity.*;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.helper.SSEEmitterHelper;
import com.moyz.adi.common.mapper.KnowledgeBaseMapper; import com.moyz.adi.common.mapper.KnowledgeBaseMapper;
import com.moyz.adi.common.util.BizPager; import com.moyz.adi.common.util.BizPager;
import com.moyz.adi.common.util.LocalDateTimeUtil; import com.moyz.adi.common.util.LocalDateTimeUtil;
import com.moyz.adi.common.vo.SseAskParams;
import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.parser.TextDocumentParser; import dev.langchain4j.data.document.parser.TextDocumentParser;
import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentParser; import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentParser;
import dev.langchain4j.data.document.parser.apache.poi.ApachePoiDocumentParser; import dev.langchain4j.data.document.parser.apache.poi.ApachePoiDocumentParser;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.Response;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.context.annotation.Lazy;
import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.text.MessageFormat; import java.text.MessageFormat;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.ArrayList; import java.util.*;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import static com.moyz.adi.common.cosntant.AdiConstant.POI_DOC_TYPES; import static com.moyz.adi.common.cosntant.AdiConstant.POI_DOC_TYPES;
import static com.moyz.adi.common.enums.ErrorEnum.*; import static com.moyz.adi.common.enums.ErrorEnum.*;
@ -39,6 +45,10 @@ import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.load
@Service @Service
public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, KnowledgeBase> { public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, KnowledgeBase> {
@Lazy
@Resource
private KnowledgeBaseService _this;
@Resource @Resource
private StringRedisTemplate stringRedisTemplate; private StringRedisTemplate stringRedisTemplate;
@ -54,6 +64,12 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
@Resource @Resource
private FileService fileService; private FileService fileService;
@Resource
private SSEEmitterHelper sseEmitterHelper;
@Resource
private UserDayCostService userDayCostService;
public KnowledgeBase saveOrUpdate(KbEditReq kbEditReq) { public KnowledgeBase saveOrUpdate(KbEditReq kbEditReq) {
String uuid = kbEditReq.getUuid(); String uuid = kbEditReq.getUuid();
KnowledgeBase knowledgeBase = new KnowledgeBase(); KnowledgeBase knowledgeBase = new KnowledgeBase();
@ -184,7 +200,6 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
} }
} }
public boolean softDelete(String uuid) { public boolean softDelete(String uuid) {
checkPrivilege(null, uuid); checkPrivilege(null, uuid);
return ChainWrappers.lambdaUpdateChain(baseMapper) return ChainWrappers.lambdaUpdateChain(baseMapper)
@ -193,8 +208,29 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
.update(); .update();
} }
public KnowledgeBaseQaRecord answerAndRecord(String kbUuid, String question, String modelName) { public KnowledgeBaseQaRecord ask(String kbUuid, String question, String modelName) {
checkRequestTimesOrThrow();
KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
Pair<String, Response<AiMessage>> responsePair = ragService.retrieveAndAsk(kbUuid, question, modelName);
Response<AiMessage> ar = responsePair.getRight();
int inputTokenCount = ar.tokenUsage().inputTokenCount();
int outputTokenCount = ar.tokenUsage().outputTokenCount();
userDayCostService.appendCostToUser(ThreadContext.getCurrentUser(), inputTokenCount + outputTokenCount);
return knowledgeBaseQaRecordService.createNewRecord(ThreadContext.getCurrentUser(), knowledgeBase, question, responsePair.getLeft(), inputTokenCount, ar.content().text(), outputTokenCount);
}
public SseEmitter sseAsk(String kbUuid, QAReq req) {
checkRequestTimesOrThrow();
SseEmitter sseEmitter = new SseEmitter();
_this.retrieveAndPushToLLM(ThreadContext.getCurrentUser(), sseEmitter, kbUuid, req);
return sseEmitter;
}
/**
* 知识库问答限额判断
*/
private void checkRequestTimesOrThrow() {
String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd")); String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd"));
String askTimes = stringRedisTemplate.opsForValue().get(key); String askTimes = stringRedisTemplate.opsForValue().get(key);
String askQuota = SysConfigService.getByKey("quota_by_qa_ask_daily"); String askQuota = SysConfigService.getByKey("quota_by_qa_ask_daily");
@ -202,19 +238,24 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
throw new BaseException(A_QA_ASK_LIMIT); throw new BaseException(A_QA_ASK_LIMIT);
} }
stringRedisTemplate.opsForValue().increment(key); stringRedisTemplate.opsForValue().increment(key);
}
@Async
public void retrieveAndPushToLLM(User user, SseEmitter sseEmitter, String kbUuid, QAReq req) {
log.info("retrieveAndPushToLLM,kbUuid:{},userId:{}", kbUuid, user.getId());
KnowledgeBase knowledgeBase = getOrThrow(kbUuid); KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
String answer = ragService.findAnswer(kbUuid, question, modelName);
String uuid = UUID.randomUUID().toString().replace("-", ""); String prompt = ragService.retrieveAndCreatePrompt(kbUuid, req.getQuestion()).text();
KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord(); SseAskParams sseAskParams = new SseAskParams();
newObj.setKbId(knowledgeBase.getId()); sseAskParams.setSystemMessage(StringUtils.EMPTY);
newObj.setKbUuid((knowledgeBase.getUuid())); sseAskParams.setSseEmitter(sseEmitter);
newObj.setUuid(uuid); sseAskParams.setUserMessage(prompt);
newObj.setUserId(ThreadContext.getCurrentUserId()); sseAskParams.setModelName(req.getModelName());
newObj.setQuestion(question); sseEmitterHelper.process(user, sseAskParams, (response, promptMeta, answerMeta) -> {
newObj.setAnswer(answer); knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), prompt, promptMeta.getTokens(), response, answerMeta.getTokens());
knowledgeBaseQaRecordService.save(newObj); userDayCostService.appendCostToUser(user, promptMeta.getTokens() + answerMeta.getTokens());
return knowledgeBaseQaRecordService.lambdaQuery().eq(KnowledgeBaseQaRecord::getUuid, uuid).one(); });
} }
public KnowledgeBase getOrThrow(String kbUuid) { public KnowledgeBase getOrThrow(String kbUuid) {
@ -248,4 +289,5 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
throw new BaseException(A_USER_NOT_AUTH); throw new BaseException(A_USER_NOT_AUTH);
} }
} }
} }

View File

@ -4,7 +4,10 @@ import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.interfaces.AbstractLLMService; import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.vo.OpenAiSetting; import com.moyz.adi.common.vo.OpenAiSetting;
import com.theokanning.openai.OpenAiError;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.model.openai.OpenAiChatModel;
@ -12,6 +15,7 @@ import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import java.net.Proxy; import java.net.Proxy;
import java.time.Duration; import java.time.Duration;
@ -49,6 +53,16 @@ public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
return builder.build(); return builder.build();
} }
@Override
protected String parseError(Object error) {
if(error instanceof OpenAiHttpException){
OpenAiHttpException openAiHttpException = (OpenAiHttpException)error;
OpenAiError openAiError = JsonUtil.fromJson(openAiHttpException.getMessage(), OpenAiError.class);
return openAiError.getError().getMessage();
}
return Strings.EMPTY;
}
@Override @Override
protected ChatLanguageModel buildChatLLM() { protected ChatLanguageModel buildChatLLM() {
if (StringUtils.isBlank(setting.getSecretKey())) { if (StringUtils.isBlank(setting.getSecretKey())) {

View File

@ -49,4 +49,9 @@ public class QianFanLLMService extends AbstractLLMService<QianFanSetting> {
.secretKey(setting.getSecretKey()) .secretKey(setting.getSecretKey())
.build(); .build();
} }
@Override
protected String parseError(Object error) {
return null;
}
} }

View File

@ -1,22 +1,30 @@
package com.moyz.adi.common.service; package com.moyz.adi.common.service;
import com.moyz.adi.common.helper.LLMContext; import com.moyz.adi.common.helper.LLMContext;
import com.moyz.adi.common.interfaces.TriConsumer;
import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore; import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore;
import com.moyz.adi.common.vo.AnswerMeta;
import com.moyz.adi.common.vo.PromptMeta;
import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter; import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.splitter.DocumentSplitters; import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel; import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.openai.OpenAiTokenizer; import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor; import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.Triple;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -104,16 +112,7 @@ public class RAGService {
getEmbeddingStoreIngestor().ingest(document); getEmbeddingStoreIngestor().ingest(document);
} }
/** public Prompt retrieveAndCreatePrompt(String kbUuid, String question) {
* 召回并搜索
*
* @param kbUuid 知识库uuid
* @param question 用户的问题
* @param modelName LLM model name
* @return
*/
public String findAnswer(String kbUuid, String question, String modelName) {
// Embed the question // Embed the question
Embedding questionEmbedding = embeddingModel.embed(question).content(); Embedding questionEmbedding = embeddingModel.embed(question).content();
@ -129,10 +128,25 @@ public class RAGService {
.collect(joining("\n\n")); .collect(joining("\n\n"));
if (StringUtils.isBlank(information)) { if (StringUtils.isBlank(information)) {
return StringUtils.EMPTY; return null;
} }
Prompt prompt = promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information))); return promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information)));
}
return new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage()); /**
* 召回并提问
*
* @param kbUuid 知识库uuid
* @param question 用户的问题
* @param modelName LLM model name
* @return
*/
public Pair<String, Response<AiMessage>> retrieveAndAsk(String kbUuid, String question, String modelName) {
Prompt prompt = retrieveAndCreatePrompt(kbUuid, question);
if (null == prompt) {
return null;
}
Response<AiMessage> response = new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage());
return new ImmutablePair<>(prompt.text(), response);
} }
} }

View File

@ -18,6 +18,23 @@ import java.util.List;
@Service @Service
public class UserDayCostService extends ServiceImpl<UserDayCostMapper, UserDayCost> { public class UserDayCostService extends ServiceImpl<UserDayCostMapper, UserDayCost> {
public void appendCostToUser(User user, int tokens) {
UserDayCost userDayCost = getTodayCost(user);
UserDayCost saveOrUpdateInst = new UserDayCost();
if (null == userDayCost) {
saveOrUpdateInst.setUserId(user.getId());
saveOrUpdateInst.setDay(LocalDateTimeUtil.getToday());
saveOrUpdateInst.setTokens(tokens);
saveOrUpdateInst.setRequests(1);
saveOrUpdateInst.setSecretKeyType(UserUtil.getSecretType(user));
} else {
saveOrUpdateInst.setId(userDayCost.getId());
saveOrUpdateInst.setTokens(userDayCost.getTokens() + tokens);
saveOrUpdateInst.setRequests(userDayCost.getRequests() + 1);
}
saveOrUpdate(saveOrUpdateInst);
}
public CostStat costStatByUser(long userId) { public CostStat costStatByUser(long userId) {
CostStat result = new CostStat(); CostStat result = new CostStat();

View File

@ -6,6 +6,6 @@ import lombok.Data;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
public class ChatMeta { public class ChatMeta {
private QuestionMeta question; private PromptMeta question;
private AnswerMeta answer; private AnswerMeta answer;
} }

View File

@ -5,7 +5,7 @@ import lombok.Data;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
public class QuestionMeta { public class PromptMeta {
private Integer tokens; private Integer tokens;
private String uuid; private String uuid;
} }

View File

@ -21,4 +21,5 @@ public class SseAskParams {
private SseEmitter sseEmitter; private SseEmitter sseEmitter;
private String modelName;
} }

View File

@ -200,11 +200,11 @@ COMMENT ON COLUMN public.adi_prompt.is_deleted IS '0:未删除1已删除';
CREATE TABLE public.adi_sys_config CREATE TABLE public.adi_sys_config
( (
id bigserial primary key, id bigserial primary key,
name character varying(100) DEFAULT ''::character varying NOT NULL, name character varying(100) DEFAULT ''::character varying NOT NULL,
value character varying(1000) DEFAULT ''::character varying NOT NULL, value character varying(1000) DEFAULT ''::character varying NOT NULL,
create_time timestamp DEFAULT localtimestamp NOT NULL, create_time timestamp DEFAULT localtimestamp NOT NULL,
update_time timestamp DEFAULT localtimestamp NOT NULL, update_time timestamp DEFAULT localtimestamp NOT NULL,
is_deleted boolean DEFAULT false NOT NULL is_deleted boolean DEFAULT false NOT NULL
); );
COMMENT ON TABLE public.adi_sys_config IS '系统配置表'; COMMENT ON TABLE public.adi_sys_config IS '系统配置表';
@ -497,7 +497,10 @@ create table adi_knowledge_base_qa_record
kb_id bigint DEFAULT '0'::bigint NOT NULL, kb_id bigint DEFAULT '0'::bigint NOT NULL,
kb_uuid varchar(32) default ''::character varying not null, kb_uuid varchar(32) default ''::character varying not null,
question varchar(1000) default ''::character varying not null, question varchar(1000) default ''::character varying not null,
prompt text default ''::character varying not null,
prompt_tokens integer DEFAULT 0 NOT NULL,
answer text default ''::character varying not null, answer text default ''::character varying not null,
answer_tokens integer DEFAULT 0 NOT NULL,
source_file_ids varchar(500) default ''::character varying not null, source_file_ids varchar(500) default ''::character varying not null,
user_id bigint default '0' NOT NULL, user_id bigint default '0' NOT NULL,
create_time timestamp default CURRENT_TIMESTAMP not null, create_time timestamp default CURRENT_TIMESTAMP not null,
@ -511,10 +514,16 @@ comment on column adi_knowledge_base_qa_record.kb_id is '所属知识库id';
comment on column adi_knowledge_base_qa_record.kb_uuid is '所属知识库uuid'; comment on column adi_knowledge_base_qa_record.kb_uuid is '所属知识库uuid';
comment on column adi_knowledge_base_qa_record.question is '问题'; comment on column adi_knowledge_base_qa_record.question is '用户的原始问题';
comment on column adi_knowledge_base_qa_record.prompt is '提供给LLM的提示词';
comment on column adi_knowledge_base_qa_record.prompt_tokens is '提示词消耗的token';
comment on column adi_knowledge_base_qa_record.answer is '答案'; comment on column adi_knowledge_base_qa_record.answer is '答案';
comment on column adi_knowledge_base_qa_record.answer_tokens is '答案消耗的token';
comment on column adi_knowledge_base_qa_record.source_file_ids is '来源文档id,以逗号隔开'; comment on column adi_knowledge_base_qa_record.source_file_ids is '来源文档id,以逗号隔开';
comment on column adi_knowledge_base_qa_record.user_id is '提问用户id'; comment on column adi_knowledge_base_qa_record.user_id is '提问用户id';