fix:知识库token计算问题
This commit is contained in:
parent
b1a9cba31e
commit
498e9131b9
18
README.md
18
README.md
|
@ -5,11 +5,11 @@
|
|||
|
||||
> **该项目如对您有帮助,欢迎点赞**
|
||||
|
||||
### 体验网址
|
||||
## 体验网址
|
||||
|
||||
[http://www.aideepin.com](http://www.aideepin.com/)
|
||||
|
||||
### 功能点
|
||||
## 功能点
|
||||
|
||||
* 注册&登录
|
||||
* 多会话(多角色)
|
||||
|
@ -19,14 +19,14 @@
|
|||
* 基于大模型的知识库(RAG)
|
||||
* 多模型随意切换
|
||||
|
||||
### 接入的模型:
|
||||
## 接入的模型:
|
||||
|
||||
* ChatGPT 3.5
|
||||
* 通义千问
|
||||
* 文心一言
|
||||
* DALL-E 2
|
||||
|
||||
### 技术栈
|
||||
## 技术栈
|
||||
|
||||
该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web)
|
||||
|
||||
|
@ -44,9 +44,9 @@ springboot3.0.5
|
|||
|
||||
vue3+typescript+pnpm
|
||||
|
||||
### 如何部署
|
||||
## 如何部署
|
||||
|
||||
#### 初始化
|
||||
### 初始化
|
||||
|
||||
初始化数据库
|
||||
|
||||
|
@ -65,7 +65,7 @@ update adi_sys_config set value = '{"api_key":"my_dashcope_api_key"}' where name
|
|||
* redis: application-[dev|prod].xml中的spring.data.redis
|
||||
* mail: application.xml中的spring.mail
|
||||
|
||||
#### 编译及运行
|
||||
### 编译及运行
|
||||
|
||||
* 进入项目
|
||||
|
||||
|
@ -100,12 +100,12 @@ docker run -d \
|
|||
aideepin:0.0.1
|
||||
```
|
||||
|
||||
### 待办:
|
||||
## 待办:
|
||||
|
||||
增强RAG
|
||||
|
||||
|
||||
### 截图
|
||||
## 截图
|
||||
|
||||
**AI聊天:**
|
||||
![1691583184761](image/README/1691583184761.png)
|
||||
|
|
|
@ -5,12 +5,15 @@ import com.moyz.adi.common.dto.QAReq;
|
|||
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
|
||||
import com.moyz.adi.common.service.KnowledgeBaseQaRecordService;
|
||||
import com.moyz.adi.common.service.KnowledgeBaseService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.annotation.Resource;
|
||||
import jakarta.validation.constraints.Min;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
@Tag(name = "知识库问答controller")
|
||||
@RequestMapping("/knowledge-base/qa/")
|
||||
|
@ -25,7 +28,13 @@ public class KnowledgeBaseQAController {
|
|||
|
||||
@PostMapping("/ask/{kbUuid}")
|
||||
public KnowledgeBaseQaRecord ask(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) {
|
||||
return knowledgeBaseService.answerAndRecord(kbUuid, req.getQuestion(), req.getModelName());
|
||||
return knowledgeBaseService.ask(kbUuid, req.getQuestion(), req.getModelName());
|
||||
}
|
||||
|
||||
/**
 * Streaming variant of the knowledge-base question endpoint.
 * <p>
 * Declares a {@code text/event-stream} response and hands the request straight to
 * {@code knowledgeBaseService.sseAsk}, returning the {@link SseEmitter} that call produces.
 *
 * @param kbUuid uuid of the knowledge base, taken from the request path
 * @param req    validated request body describing the question
 * @return the emitter returned by the service
 */
@Operation(summary = "流式响应")
|
||||
@PostMapping(value = "/process/{kbUuid}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
public SseEmitter sseAsk(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) {
|
||||
return knowledgeBaseService.sseAsk(kbUuid, req);
|
||||
}
|
||||
|
||||
@GetMapping("/record/search")
|
||||
|
|
|
@ -31,11 +31,24 @@ public class KnowledgeBaseQaRecord extends BaseEntity {
|
|||
@TableField("question")
|
||||
private String question;
|
||||
|
||||
@Schema(title = "最终提供给LLM的提示词")
|
||||
@TableField("prompt")
|
||||
private String prompt;
|
||||
|
||||
@Schema(title = "提供给LLM的提示词所消耗的token数量")
|
||||
@TableField("prompt_tokens")
|
||||
private Integer promptTokens;
|
||||
|
||||
@Schema(title = "答案")
|
||||
@TableField("answer")
|
||||
private String answer;
|
||||
|
||||
@Schema(title = "答案消耗的token")
|
||||
@TableField("answer_tokens")
|
||||
private Integer answerTokens;
|
||||
|
||||
@Schema(title = "提问用户id")
|
||||
@TableField("user_id")
|
||||
private Long userId;
|
||||
|
||||
}
|
||||
|
|
|
@ -14,6 +14,13 @@ public class RateLimitHelper {
|
|||
@Resource
|
||||
private StringRedisTemplate stringRedisTemplate;
|
||||
|
||||
/**
|
||||
* 按固定时间窗口计算请求次数
|
||||
*
|
||||
* @param requestTimesKey redis key
|
||||
* @param rateLimitConfig 请求频率限制配置
|
||||
* @return
|
||||
*/
|
||||
public boolean checkRequestTimes(String requestTimesKey, RequestRateLimit rateLimitConfig) {
|
||||
int requestCountInTimeWindow = 0;
|
||||
String rateLimitVal = stringRedisTemplate.opsForValue().get(requestTimesKey);
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
package com.moyz.adi.common.helper;
|
||||
|
||||
import com.moyz.adi.common.cosntant.RedisKeyConstant;
|
||||
import com.moyz.adi.common.entity.User;
|
||||
import com.moyz.adi.common.interfaces.TriConsumer;
|
||||
import com.moyz.adi.common.util.LocalCache;
|
||||
import com.moyz.adi.common.vo.AnswerMeta;
|
||||
import com.moyz.adi.common.vo.PromptMeta;
|
||||
import com.moyz.adi.common.vo.SseAskParams;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class SSEEmitterHelper {
|
||||
|
||||
@Resource
|
||||
private StringRedisTemplate stringRedisTemplate;
|
||||
|
||||
@Resource
|
||||
private RateLimitHelper rateLimitHelper;
|
||||
|
||||
public void process(User user, SseAskParams sseAskParams, TriConsumer<String, PromptMeta, AnswerMeta> consumer) {
|
||||
SseEmitter sseEmitter = sseAskParams.getSseEmitter();
|
||||
|
||||
//rate limit by system
|
||||
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
|
||||
if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) {
|
||||
sendErrorMsg(sseEmitter, "访问太过频繁");
|
||||
return;
|
||||
}
|
||||
|
||||
//Check: If still waiting response
|
||||
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
|
||||
String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
|
||||
if (StringUtils.isNotBlank(askingVal)) {
|
||||
sendErrorMsg(sseEmitter, "正在回复中...");
|
||||
return;
|
||||
}
|
||||
stringRedisTemplate.opsForValue().set(askingKey, "1", 15, TimeUnit.SECONDS);
|
||||
try {
|
||||
sseEmitter.send(SseEmitter.event().name("start"));
|
||||
} catch (IOException e) {
|
||||
log.error("error", e);
|
||||
sseEmitter.completeWithError(e);
|
||||
stringRedisTemplate.delete(askingKey);
|
||||
return;
|
||||
}
|
||||
|
||||
rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG);
|
||||
|
||||
sseEmitter.onCompletion(() -> {
|
||||
log.info("response complete,uid:{}", user.getId());
|
||||
});
|
||||
sseEmitter.onTimeout(() -> log.warn("sseEmitter timeout,uid:{},on timeout:{}", user.getId(), sseEmitter.getTimeout()));
|
||||
sseEmitter.onError(
|
||||
throwable -> {
|
||||
try {
|
||||
log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable);
|
||||
sseEmitter.send(SseEmitter.event().name("error").data(throwable.getMessage()));
|
||||
} catch (IOException e) {
|
||||
log.error("error", e);
|
||||
} finally {
|
||||
stringRedisTemplate.delete(askingKey);
|
||||
}
|
||||
}
|
||||
);
|
||||
new LLMContext(sseAskParams.getModelName()).getLLMService().sseChat(sseAskParams, (response, promptMeta, answerMeta) -> {
|
||||
try {
|
||||
consumer.accept((String) response, (PromptMeta) promptMeta, (AnswerMeta) answerMeta);
|
||||
} catch (Exception e) {
|
||||
log.error("error:", e);
|
||||
} finally {
|
||||
stringRedisTemplate.delete(askingKey);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public void sendErrorMsg(SseEmitter sseEmitter, String errorMsg) {
|
||||
try {
|
||||
sseEmitter.send(SseEmitter.event().name("error").data(errorMsg));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
sseEmitter.complete();
|
||||
}
|
||||
}
|
|
@ -1,23 +1,27 @@
|
|||
package com.moyz.adi.common.interfaces;
|
||||
|
||||
import com.fasterxml.jackson.databind.util.JSONPObject;
|
||||
import com.moyz.adi.common.util.JsonUtil;
|
||||
import com.moyz.adi.common.util.LocalCache;
|
||||
import com.moyz.adi.common.vo.AnswerMeta;
|
||||
import com.moyz.adi.common.vo.ChatMeta;
|
||||
import com.moyz.adi.common.vo.QuestionMeta;
|
||||
import com.moyz.adi.common.vo.PromptMeta;
|
||||
import com.moyz.adi.common.vo.SseAskParams;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import dev.langchain4j.service.TokenStream;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.Proxy;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
@Slf4j
|
||||
|
@ -32,7 +36,7 @@ public abstract class AbstractLLMService<T> {
|
|||
protected StreamingChatLanguageModel streamingChatLanguageModel;
|
||||
protected ChatLanguageModel chatLanguageModel;
|
||||
|
||||
public AbstractLLMService(String modelName, String settingName, Class<T> clazz, Proxy proxy){
|
||||
public AbstractLLMService(String modelName, String settingName, Class<T> clazz, Proxy proxy) {
|
||||
this.modelName = modelName;
|
||||
this.proxy = proxy;
|
||||
String st = LocalCache.CONFIGS.get(settingName);
|
||||
|
@ -66,11 +70,13 @@ public abstract class AbstractLLMService<T> {
|
|||
|
||||
protected abstract StreamingChatLanguageModel buildStreamingChatLLM();
|
||||
|
||||
public String chat(ChatMessage chatMessage) {
|
||||
return getChatLLM().generate(chatMessage).content().text();
|
||||
protected abstract String parseError(Object error);
|
||||
|
||||
public Response<AiMessage> chat(ChatMessage chatMessage) {
|
||||
return getChatLLM().generate(chatMessage);
|
||||
}
|
||||
|
||||
public void sseChat(SseAskParams params, TriConsumer<String, QuestionMeta, AnswerMeta> consumer) {
|
||||
public void sseChat(SseAskParams params, TriConsumer<String, PromptMeta, AnswerMeta> consumer) {
|
||||
|
||||
//create chat assistant
|
||||
AiServices<IChatAssistant> serviceBuilder = AiServices.builder(IChatAssistant.class)
|
||||
|
@ -98,7 +104,7 @@ public abstract class AbstractLLMService<T> {
|
|||
.onComplete((response) -> {
|
||||
log.info("返回数据结束了:{}", response);
|
||||
String questionUuid = StringUtils.isNotBlank(params.getRegenerateQuestionUuid()) ? params.getRegenerateQuestionUuid() : UUID.randomUUID().toString().replace("-", "");
|
||||
QuestionMeta questionMeta = new QuestionMeta(response.tokenUsage().inputTokenCount(), questionUuid);
|
||||
PromptMeta questionMeta = new PromptMeta(response.tokenUsage().inputTokenCount(), questionUuid);
|
||||
AnswerMeta answerMeta = new AnswerMeta(response.tokenUsage().outputTokenCount(), UUID.randomUUID().toString().replace("-", ""));
|
||||
ChatMeta chatMeta = new ChatMeta(questionMeta, answerMeta);
|
||||
String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", "");
|
||||
|
@ -116,7 +122,11 @@ public abstract class AbstractLLMService<T> {
|
|||
.onError((error) -> {
|
||||
log.error("stream error", error);
|
||||
try {
|
||||
params.getSseEmitter().send(SseEmitter.event().name("error").data(error.getMessage()));
|
||||
String errorMsg = parseError(error);
|
||||
if(StringUtils.isBlank(errorMsg)){
|
||||
errorMsg = error.getMessage();
|
||||
}
|
||||
params.getSseEmitter().send(SseEmitter.event().name("error").data(errorMsg));
|
||||
} catch (IOException e) {
|
||||
log.error("sse error", e);
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
|||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.moyz.adi.common.base.ThreadContext;
|
||||
import com.moyz.adi.common.cosntant.AdiConstant;
|
||||
import com.moyz.adi.common.cosntant.RedisKeyConstant;
|
||||
import com.moyz.adi.common.dto.AskReq;
|
||||
import com.moyz.adi.common.entity.Conversation;
|
||||
import com.moyz.adi.common.entity.ConversationMessage;
|
||||
|
@ -13,15 +12,14 @@ import com.moyz.adi.common.entity.UserDayCost;
|
|||
import com.moyz.adi.common.enums.ChatMessageRoleEnum;
|
||||
import com.moyz.adi.common.enums.ErrorEnum;
|
||||
import com.moyz.adi.common.exception.BaseException;
|
||||
import com.moyz.adi.common.helper.LLMContext;
|
||||
import com.moyz.adi.common.helper.QuotaHelper;
|
||||
import com.moyz.adi.common.helper.RateLimitHelper;
|
||||
import com.moyz.adi.common.helper.SSEEmitterHelper;
|
||||
import com.moyz.adi.common.mapper.ConversationMessageMapper;
|
||||
import com.moyz.adi.common.util.LocalCache;
|
||||
import com.moyz.adi.common.util.LocalDateTimeUtil;
|
||||
import com.moyz.adi.common.util.UserUtil;
|
||||
import com.moyz.adi.common.vo.AnswerMeta;
|
||||
import com.moyz.adi.common.vo.QuestionMeta;
|
||||
import com.moyz.adi.common.vo.PromptMeta;
|
||||
import com.moyz.adi.common.vo.SseAskParams;
|
||||
import com.theokanning.openai.completion.chat.ChatMessageRole;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
|
@ -33,17 +31,13 @@ import jakarta.annotation.Resource;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static com.moyz.adi.common.enums.ErrorEnum.B_MESSAGE_NOT_FOUND;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||
|
@ -56,9 +50,6 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
|||
@Resource
|
||||
private ConversationMessageService _this;
|
||||
|
||||
@Resource
|
||||
private StringRedisTemplate stringRedisTemplate;
|
||||
|
||||
@Resource
|
||||
private QuotaHelper quotaHelper;
|
||||
|
||||
|
@ -70,51 +61,43 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
|||
private ConversationService conversationService;
|
||||
|
||||
@Resource
|
||||
private RateLimitHelper rateLimitHelper;
|
||||
private SSEEmitterHelper sseEmitterHelper;
|
||||
|
||||
|
||||
public SseEmitter sseAsk(AskReq askReq) {
|
||||
SseEmitter sseEmitter = new SseEmitter();
|
||||
User user = ThreadContext.getCurrentUser();
|
||||
_this.asyncCheckAndPushToClient(sseEmitter, user, askReq);
|
||||
_this.asyncCheckAndPushToClient(sseEmitter, ThreadContext.getCurrentUser(), askReq);
|
||||
return sseEmitter;
|
||||
}
|
||||
|
||||
private boolean check(SseEmitter sseEmitter, User user, AskReq askReq) {
|
||||
try {
|
||||
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
|
||||
String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
|
||||
//check 1: still waiting response
|
||||
if (StringUtils.isNotBlank(askingVal)) {
|
||||
sendErrorMsg(sseEmitter, "正在回复中...");
|
||||
return false;
|
||||
}
|
||||
|
||||
//check 2: the conversation has been deleted
|
||||
//check 1: the conversation has been deleted
|
||||
Conversation delConv = conversationService.lambdaQuery()
|
||||
.eq(Conversation::getUuid, askReq.getConversationUuid())
|
||||
.eq(Conversation::getIsDeleted, true)
|
||||
.one();
|
||||
if (null != delConv) {
|
||||
sendErrorMsg(sseEmitter, "该对话已经删除");
|
||||
sseEmitterHelper.sendErrorMsg(sseEmitter, "该对话已经删除");
|
||||
return false;
|
||||
}
|
||||
|
||||
//check 3: conversation quota
|
||||
//check 2: conversation quota
|
||||
Long convsCount = conversationService.lambdaQuery()
|
||||
.eq(Conversation::getUserId, user.getId())
|
||||
.eq(Conversation::getIsDeleted, false)
|
||||
.count();
|
||||
long convsMax = Integer.parseInt(LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.CONVERSATION_MAX_NUM));
|
||||
if (convsCount >= convsMax) {
|
||||
sendErrorMsg(sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax);
|
||||
sseEmitterHelper.sendErrorMsg(sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax);
|
||||
return false;
|
||||
}
|
||||
|
||||
//check 4: current user's quota
|
||||
//check 3: current user's quota
|
||||
ErrorEnum errorMsg = quotaHelper.checkTextQuota(user);
|
||||
if (null != errorMsg) {
|
||||
sendErrorMsg(sseEmitter, errorMsg.getInfo());
|
||||
sseEmitterHelper.sendErrorMsg(sseEmitter, errorMsg.getInfo());
|
||||
return false;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
|
@ -125,60 +108,16 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
|||
return true;
|
||||
}
|
||||
|
||||
private void sendErrorMsg(SseEmitter sseEmitter, String errorMsg) {
|
||||
try {
|
||||
sseEmitter.send(SseEmitter.event().name("error").data(errorMsg));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
sseEmitter.complete();
|
||||
}
|
||||
|
||||
@Async
|
||||
public void asyncCheckAndPushToClient(SseEmitter sseEmitter, User user, AskReq askReq) {
|
||||
log.info("asyncCheckAndPushToClient,userId:{}", user.getId());
|
||||
//rate limit by system
|
||||
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
|
||||
if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) {
|
||||
sendErrorMsg(sseEmitter, "访问太过频繁");
|
||||
return;
|
||||
}
|
||||
|
||||
//check business rules
|
||||
if (!check(sseEmitter, user, askReq)) {
|
||||
return;
|
||||
}
|
||||
|
||||
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
|
||||
stringRedisTemplate.opsForValue().set(askingKey, "1", 15, TimeUnit.SECONDS);
|
||||
try {
|
||||
sseEmitter.send(SseEmitter.event().name("start"));
|
||||
} catch (IOException e) {
|
||||
log.error("error", e);
|
||||
sseEmitter.completeWithError(e);
|
||||
stringRedisTemplate.delete(askingKey);
|
||||
return;
|
||||
}
|
||||
|
||||
rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG);
|
||||
|
||||
sseEmitter.onCompletion(() -> {
|
||||
log.info("response complete,uid:{}", user.getId());
|
||||
});
|
||||
sseEmitter.onTimeout(() -> log.warn("sseEmitter timeout,uid:{},on timeout:{}", user.getId(), sseEmitter.getTimeout()));
|
||||
sseEmitter.onError(
|
||||
throwable -> {
|
||||
try {
|
||||
log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable);
|
||||
sseEmitter.send(SseEmitter.event().name("error").data(throwable.getMessage()));
|
||||
} catch (IOException e) {
|
||||
log.error("error", e);
|
||||
} finally {
|
||||
stringRedisTemplate.delete(askingKey);
|
||||
}
|
||||
}
|
||||
);
|
||||
SseAskParams sseAskParams = new SseAskParams();
|
||||
sseAskParams.setModelName(askReq.getModelName());
|
||||
String prompt = askReq.getPrompt();
|
||||
if (StringUtils.isNotBlank(askReq.getRegenerateQuestionUuid())) {
|
||||
prompt = getPromptMsgByQuestionUuid(askReq.getRegenerateQuestionUuid()).getRemark();
|
||||
|
@ -222,14 +161,8 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
|||
|
||||
}
|
||||
}
|
||||
new LLMContext(askReq.getModelName()).getLLMService().sseChat(sseAskParams, (response, questionMeta, answerMeta) -> {
|
||||
try {
|
||||
_this.saveAfterAiResponse(user, askReq, (String) response, (QuestionMeta) questionMeta, (AnswerMeta) answerMeta);
|
||||
} catch (Exception e) {
|
||||
log.error("error:", e);
|
||||
} finally {
|
||||
stringRedisTemplate.delete(askingKey);
|
||||
}
|
||||
sseEmitterHelper.process(user, sseAskParams, (response, questionMeta, answerMeta) -> {
|
||||
_this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -245,7 +178,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
|||
}
|
||||
|
||||
@Transactional
|
||||
public void saveAfterAiResponse(User user, AskReq askReq, String response, QuestionMeta questionMeta, AnswerMeta answerMeta) {
|
||||
public void saveAfterAiResponse(User user, AskReq askReq, String response, PromptMeta questionMeta, AnswerMeta answerMeta) {
|
||||
|
||||
int secretKeyType = StringUtils.isNotBlank(user.getSecretKey()) ? AdiConstant.SECRET_KEY_TYPE_CUSTOM : AdiConstant.SECRET_KEY_TYPE_SYSTEM;
|
||||
Conversation conversation;
|
||||
|
@ -294,7 +227,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
|||
|
||||
}
|
||||
|
||||
private void calcTodayCost(User user, Conversation conversation, QuestionMeta questionMeta, AnswerMeta answerMeta) {
|
||||
private void calcTodayCost(User user, Conversation conversation, PromptMeta questionMeta, AnswerMeta answerMeta) {
|
||||
|
||||
int todayTokenCost = questionMeta.getTokens() + answerMeta.getTokens();
|
||||
try {
|
||||
|
@ -304,20 +237,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
|||
.set(Conversation::getTokens, conversation.getTokens() + todayTokenCost)
|
||||
.update();
|
||||
|
||||
UserDayCost userDayCost = userDayCostService.getTodayCost(user);
|
||||
UserDayCost saveOrUpdateInst = new UserDayCost();
|
||||
if (null == userDayCost) {
|
||||
saveOrUpdateInst.setUserId(user.getId());
|
||||
saveOrUpdateInst.setDay(LocalDateTimeUtil.getToday());
|
||||
saveOrUpdateInst.setTokens(todayTokenCost);
|
||||
saveOrUpdateInst.setRequests(1);
|
||||
saveOrUpdateInst.setSecretKeyType(UserUtil.getSecretType(user));
|
||||
} else {
|
||||
saveOrUpdateInst.setId(userDayCost.getId());
|
||||
saveOrUpdateInst.setTokens(userDayCost.getTokens() + todayTokenCost);
|
||||
saveOrUpdateInst.setRequests(userDayCost.getRequests() + 1);
|
||||
}
|
||||
userDayCostService.saveOrUpdate(saveOrUpdateInst);
|
||||
userDayCostService.appendCostToUser(user, todayTokenCost);
|
||||
} catch (Exception e) {
|
||||
log.error("calcTodayCost error", e);
|
||||
}
|
||||
|
|
|
@ -41,6 +41,11 @@ public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> {
|
|||
.build();
|
||||
}
|
||||
|
||||
/**
 * Provider-specific error parsing hook; this implementation extracts nothing.
 *
 * @param error the error object passed in by the caller (unused here)
 * @return always {@code null}
 */
@Override
|
||||
protected String parseError(Object error) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ChatLanguageModel buildChatLLM() {
|
||||
if (StringUtils.isBlank(setting.getApiKey())) {
|
||||
|
|
|
@ -5,13 +5,17 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
|||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
|
||||
import com.moyz.adi.common.base.ThreadContext;
|
||||
import com.moyz.adi.common.entity.KnowledgeBase;
|
||||
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
|
||||
import com.moyz.adi.common.entity.User;
|
||||
import com.moyz.adi.common.exception.BaseException;
|
||||
import com.moyz.adi.common.mapper.KnowledgeBaseQaRecordMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND;
|
||||
|
||||
@Slf4j
|
||||
|
@ -21,6 +25,7 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
|
|||
public Page<KnowledgeBaseQaRecord> search(String kbUuid, String keyword, Integer currentPage, Integer pageSize) {
|
||||
LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>();
|
||||
wrapper.eq(KnowledgeBaseQaRecord::getKbUuid, kbUuid);
|
||||
wrapper.eq(KnowledgeBaseQaRecord::getIsDeleted, false);
|
||||
if (!ThreadContext.getCurrentUser().getIsAdmin()) {
|
||||
wrapper.eq(KnowledgeBaseQaRecord::getUserId, ThreadContext.getCurrentUserId());
|
||||
}
|
||||
|
@ -31,6 +36,36 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
|
|||
return baseMapper.selectPage(new Page<>(currentPage, pageSize), wrapper);
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建新的QA记录
|
||||
*
|
||||
* @param knowledgeBase 所属的知识库
|
||||
* @param question 用户的原始问题
|
||||
* @param prompt 根据{question}生成的最终提示词
|
||||
* @param promptTokens 提示词消耗的token
|
||||
* @param answer 答案
|
||||
* @param answerTokens 答案消耗的token
|
||||
* @return
|
||||
*/
|
||||
public KnowledgeBaseQaRecord createNewRecord(User user, KnowledgeBase knowledgeBase, String question, String prompt, int promptTokens, String answer, int answerTokens) {
|
||||
String uuid = UUID.randomUUID().toString().replace("-", "");
|
||||
KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord();
|
||||
newObj.setKbId(knowledgeBase.getId());
|
||||
newObj.setKbUuid((knowledgeBase.getUuid()));
|
||||
newObj.setUuid(uuid);
|
||||
newObj.setUserId(user.getId());
|
||||
newObj.setQuestion(question);
|
||||
newObj.setPrompt(prompt);
|
||||
newObj.setPromptTokens(promptTokens);
|
||||
newObj.setAnswer(answer);
|
||||
newObj.setAnswerTokens(answerTokens);
|
||||
baseMapper.insert(newObj);
|
||||
|
||||
LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>();
|
||||
wrapper.eq(KnowledgeBaseQaRecord::getUuid, uuid);
|
||||
return baseMapper.selectOne(wrapper);
|
||||
}
|
||||
|
||||
public boolean softDelele(String uuid) {
|
||||
if (ThreadContext.getCurrentUser().getIsAdmin()) {
|
||||
return ChainWrappers.lambdaUpdateChain(baseMapper)
|
||||
|
|
|
@ -7,29 +7,35 @@ import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
|
|||
import com.moyz.adi.common.base.ThreadContext;
|
||||
import com.moyz.adi.common.cosntant.RedisKeyConstant;
|
||||
import com.moyz.adi.common.dto.KbEditReq;
|
||||
import com.moyz.adi.common.dto.QAReq;
|
||||
import com.moyz.adi.common.entity.*;
|
||||
import com.moyz.adi.common.exception.BaseException;
|
||||
import com.moyz.adi.common.helper.SSEEmitterHelper;
|
||||
import com.moyz.adi.common.mapper.KnowledgeBaseMapper;
|
||||
import com.moyz.adi.common.util.BizPager;
|
||||
import com.moyz.adi.common.util.LocalDateTimeUtil;
|
||||
import com.moyz.adi.common.vo.SseAskParams;
|
||||
import dev.langchain4j.data.document.Document;
|
||||
import dev.langchain4j.data.document.parser.TextDocumentParser;
|
||||
import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentParser;
|
||||
import dev.langchain4j.data.document.parser.apache.poi.ApachePoiDocumentParser;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.text.MessageFormat;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.*;
|
||||
|
||||
import static com.moyz.adi.common.cosntant.AdiConstant.POI_DOC_TYPES;
|
||||
import static com.moyz.adi.common.enums.ErrorEnum.*;
|
||||
|
@ -39,6 +45,10 @@ import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.load
|
|||
@Service
|
||||
public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, KnowledgeBase> {
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private KnowledgeBaseService _this;
|
||||
|
||||
@Resource
|
||||
private StringRedisTemplate stringRedisTemplate;
|
||||
|
||||
|
@ -54,6 +64,12 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
|||
@Resource
|
||||
private FileService fileService;
|
||||
|
||||
@Resource
|
||||
private SSEEmitterHelper sseEmitterHelper;
|
||||
|
||||
@Resource
|
||||
private UserDayCostService userDayCostService;
|
||||
|
||||
public KnowledgeBase saveOrUpdate(KbEditReq kbEditReq) {
|
||||
String uuid = kbEditReq.getUuid();
|
||||
KnowledgeBase knowledgeBase = new KnowledgeBase();
|
||||
|
@ -184,7 +200,6 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
public boolean softDelete(String uuid) {
|
||||
checkPrivilege(null, uuid);
|
||||
return ChainWrappers.lambdaUpdateChain(baseMapper)
|
||||
|
@ -193,8 +208,29 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
|||
.update();
|
||||
}
|
||||
|
||||
public KnowledgeBaseQaRecord answerAndRecord(String kbUuid, String question, String modelName) {
|
||||
public KnowledgeBaseQaRecord ask(String kbUuid, String question, String modelName) {
|
||||
checkRequestTimesOrThrow();
|
||||
KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
|
||||
Pair<String, Response<AiMessage>> responsePair = ragService.retrieveAndAsk(kbUuid, question, modelName);
|
||||
|
||||
Response<AiMessage> ar = responsePair.getRight();
|
||||
int inputTokenCount = ar.tokenUsage().inputTokenCount();
|
||||
int outputTokenCount = ar.tokenUsage().outputTokenCount();
|
||||
userDayCostService.appendCostToUser(ThreadContext.getCurrentUser(), inputTokenCount + outputTokenCount);
|
||||
return knowledgeBaseQaRecordService.createNewRecord(ThreadContext.getCurrentUser(), knowledgeBase, question, responsePair.getLeft(), inputTokenCount, ar.content().text(), outputTokenCount);
|
||||
}
|
||||
|
||||
public SseEmitter sseAsk(String kbUuid, QAReq req) {
|
||||
checkRequestTimesOrThrow();
|
||||
SseEmitter sseEmitter = new SseEmitter();
|
||||
_this.retrieveAndPushToLLM(ThreadContext.getCurrentUser(), sseEmitter, kbUuid, req);
|
||||
return sseEmitter;
|
||||
}
|
||||
|
||||
/**
|
||||
* 知识库问答限额判断
|
||||
*/
|
||||
private void checkRequestTimesOrThrow() {
|
||||
String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd"));
|
||||
String askTimes = stringRedisTemplate.opsForValue().get(key);
|
||||
String askQuota = SysConfigService.getByKey("quota_by_qa_ask_daily");
|
||||
|
@ -202,19 +238,24 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
|||
throw new BaseException(A_QA_ASK_LIMIT);
|
||||
}
|
||||
stringRedisTemplate.opsForValue().increment(key);
|
||||
}
|
||||
|
||||
|
||||
@Async
|
||||
public void retrieveAndPushToLLM(User user, SseEmitter sseEmitter, String kbUuid, QAReq req) {
|
||||
log.info("retrieveAndPushToLLM,kbUuid:{},userId:{}", kbUuid, user.getId());
|
||||
KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
|
||||
String answer = ragService.findAnswer(kbUuid, question, modelName);
|
||||
String uuid = UUID.randomUUID().toString().replace("-", "");
|
||||
KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord();
|
||||
newObj.setKbId(knowledgeBase.getId());
|
||||
newObj.setKbUuid((knowledgeBase.getUuid()));
|
||||
newObj.setUuid(uuid);
|
||||
newObj.setUserId(ThreadContext.getCurrentUserId());
|
||||
newObj.setQuestion(question);
|
||||
newObj.setAnswer(answer);
|
||||
knowledgeBaseQaRecordService.save(newObj);
|
||||
return knowledgeBaseQaRecordService.lambdaQuery().eq(KnowledgeBaseQaRecord::getUuid, uuid).one();
|
||||
|
||||
String prompt = ragService.retrieveAndCreatePrompt(kbUuid, req.getQuestion()).text();
|
||||
SseAskParams sseAskParams = new SseAskParams();
|
||||
sseAskParams.setSystemMessage(StringUtils.EMPTY);
|
||||
sseAskParams.setSseEmitter(sseEmitter);
|
||||
sseAskParams.setUserMessage(prompt);
|
||||
sseAskParams.setModelName(req.getModelName());
|
||||
sseEmitterHelper.process(user, sseAskParams, (response, promptMeta, answerMeta) -> {
|
||||
knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), prompt, promptMeta.getTokens(), response, answerMeta.getTokens());
|
||||
userDayCostService.appendCostToUser(user, promptMeta.getTokens() + answerMeta.getTokens());
|
||||
});
|
||||
}
|
||||
|
||||
public KnowledgeBase getOrThrow(String kbUuid) {
|
||||
|
@ -248,4 +289,5 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
|||
throw new BaseException(A_USER_NOT_AUTH);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -4,7 +4,10 @@ import com.moyz.adi.common.cosntant.AdiConstant;
|
|||
import com.moyz.adi.common.enums.ErrorEnum;
|
||||
import com.moyz.adi.common.exception.BaseException;
|
||||
import com.moyz.adi.common.interfaces.AbstractLLMService;
|
||||
import com.moyz.adi.common.util.JsonUtil;
|
||||
import com.moyz.adi.common.vo.OpenAiSetting;
|
||||
import com.theokanning.openai.OpenAiError;
|
||||
import dev.ai4j.openai4j.OpenAiHttpException;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
|
@ -12,6 +15,7 @@ import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
|||
import lombok.experimental.Accessors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
|
||||
import java.net.Proxy;
|
||||
import java.time.Duration;
|
||||
|
@ -49,6 +53,16 @@ public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
|
|||
return builder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String parseError(Object error) {
|
||||
if(error instanceof OpenAiHttpException){
|
||||
OpenAiHttpException openAiHttpException = (OpenAiHttpException)error;
|
||||
OpenAiError openAiError = JsonUtil.fromJson(openAiHttpException.getMessage(), OpenAiError.class);
|
||||
return openAiError.getError().getMessage();
|
||||
}
|
||||
return Strings.EMPTY;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ChatLanguageModel buildChatLLM() {
|
||||
if (StringUtils.isBlank(setting.getSecretKey())) {
|
||||
|
|
|
@ -49,4 +49,9 @@ public class QianFanLLMService extends AbstractLLMService<QianFanSetting> {
|
|||
.secretKey(setting.getSecretKey())
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
 * Error parsing hook for the QianFan service.
 * <p>
 * No parsing is performed here: the argument is ignored.
 *
 * @param error the error object reported by the model call (unused)
 * @return always {@code null}
 */
@Override
|
||||
protected String parseError(Object error) {
|
||||
// NOTE(review): the OpenAI implementation returns an empty string for unknown errors; confirm callers handle null.
return null;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,22 +1,30 @@
|
|||
package com.moyz.adi.common.service;
|
||||
|
||||
import com.moyz.adi.common.helper.LLMContext;
|
||||
import com.moyz.adi.common.interfaces.TriConsumer;
|
||||
import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore;
|
||||
import com.moyz.adi.common.vo.AnswerMeta;
|
||||
import com.moyz.adi.common.vo.PromptMeta;
|
||||
import dev.langchain4j.data.document.Document;
|
||||
import dev.langchain4j.data.document.DocumentSplitter;
|
||||
import dev.langchain4j.data.document.splitter.DocumentSplitters;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.apache.commons.lang3.tuple.Triple;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
|
@ -104,16 +112,7 @@ public class RAGService {
|
|||
getEmbeddingStoreIngestor().ingest(document);
|
||||
}
|
||||
|
||||
/**
|
||||
* 召回并搜索
|
||||
*
|
||||
* @param kbUuid 知识库uuid
|
||||
* @param question 用户的问题
|
||||
* @param modelName LLM model name
|
||||
* @return
|
||||
*/
|
||||
public String findAnswer(String kbUuid, String question, String modelName) {
|
||||
|
||||
public Prompt retrieveAndCreatePrompt(String kbUuid, String question) {
|
||||
// Embed the question
|
||||
Embedding questionEmbedding = embeddingModel.embed(question).content();
|
||||
|
||||
|
@ -129,10 +128,25 @@ public class RAGService {
|
|||
.collect(joining("\n\n"));
|
||||
|
||||
if (StringUtils.isBlank(information)) {
|
||||
return StringUtils.EMPTY;
|
||||
return null;
|
||||
}
|
||||
Prompt prompt = promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information)));
|
||||
return promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information)));
|
||||
}
|
||||
|
||||
return new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage());
|
||||
/**
|
||||
* 召回并提问
|
||||
*
|
||||
* @param kbUuid 知识库uuid
|
||||
* @param question 用户的问题
|
||||
* @param modelName LLM model name
|
||||
* @return
|
||||
*/
|
||||
public Pair<String, Response<AiMessage>> retrieveAndAsk(String kbUuid, String question, String modelName) {
|
||||
Prompt prompt = retrieveAndCreatePrompt(kbUuid, question);
|
||||
if (null == prompt) {
|
||||
return null;
|
||||
}
|
||||
Response<AiMessage> response = new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage());
|
||||
return new ImmutablePair<>(prompt.text(), response);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,23 @@ import java.util.List;
|
|||
@Service
|
||||
public class UserDayCostService extends ServiceImpl<UserDayCostMapper, UserDayCost> {
|
||||
|
||||
public void appendCostToUser(User user, int tokens) {
|
||||
UserDayCost userDayCost = getTodayCost(user);
|
||||
UserDayCost saveOrUpdateInst = new UserDayCost();
|
||||
if (null == userDayCost) {
|
||||
saveOrUpdateInst.setUserId(user.getId());
|
||||
saveOrUpdateInst.setDay(LocalDateTimeUtil.getToday());
|
||||
saveOrUpdateInst.setTokens(tokens);
|
||||
saveOrUpdateInst.setRequests(1);
|
||||
saveOrUpdateInst.setSecretKeyType(UserUtil.getSecretType(user));
|
||||
} else {
|
||||
saveOrUpdateInst.setId(userDayCost.getId());
|
||||
saveOrUpdateInst.setTokens(userDayCost.getTokens() + tokens);
|
||||
saveOrUpdateInst.setRequests(userDayCost.getRequests() + 1);
|
||||
}
|
||||
saveOrUpdate(saveOrUpdateInst);
|
||||
}
|
||||
|
||||
public CostStat costStatByUser(long userId) {
|
||||
CostStat result = new CostStat();
|
||||
|
||||
|
|
|
@ -6,6 +6,6 @@ import lombok.Data;
|
|||
@Data
|
||||
@AllArgsConstructor
|
||||
public class ChatMeta {
|
||||
private QuestionMeta question;
|
||||
private PromptMeta question;
|
||||
private AnswerMeta answer;
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import lombok.Data;
|
|||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class QuestionMeta {
|
||||
public class PromptMeta {
|
||||
private Integer tokens;
|
||||
private String uuid;
|
||||
}
|
|
@ -21,4 +21,5 @@ public class SseAskParams {
|
|||
|
||||
private SseEmitter sseEmitter;
|
||||
|
||||
private String modelName;
|
||||
}
|
||||
|
|
|
@ -200,11 +200,11 @@ COMMENT ON COLUMN public.adi_prompt.is_deleted IS '0:未删除;1:已删除';
|
|||
-- System configuration table: named settings stored as name/value rows.
CREATE TABLE public.adi_sys_config
(
    id          bigserial primary key,
    name        character varying(100)  DEFAULT ''::character varying NOT NULL,
    value       character varying(1000) DEFAULT ''::character varying NOT NULL,
    create_time timestamp               DEFAULT localtimestamp NOT NULL,
    update_time timestamp               DEFAULT localtimestamp NOT NULL,
    is_deleted  boolean                 DEFAULT false NOT NULL
);
|
||||
|
||||
COMMENT ON TABLE public.adi_sys_config IS '系统配置表';
|
||||
|
@ -497,7 +497,10 @@ create table adi_knowledge_base_qa_record
|
|||
kb_id bigint DEFAULT '0'::bigint NOT NULL,
|
||||
kb_uuid varchar(32) default ''::character varying not null,
|
||||
question varchar(1000) default ''::character varying not null,
|
||||
prompt text default ''::character varying not null,
|
||||
prompt_tokens integer DEFAULT 0 NOT NULL,
|
||||
answer text default ''::character varying not null,
|
||||
answer_tokens integer DEFAULT 0 NOT NULL,
|
||||
source_file_ids varchar(500) default ''::character varying not null,
|
||||
user_id bigint default '0' NOT NULL,
|
||||
create_time timestamp default CURRENT_TIMESTAMP not null,
|
||||
|
@ -511,10 +514,16 @@ comment on column adi_knowledge_base_qa_record.kb_id is '所属知识库id';
|
|||
|
||||
comment on column adi_knowledge_base_qa_record.kb_uuid is '所属知识库uuid';
|
||||
|
||||
comment on column adi_knowledge_base_qa_record.question is '问题';
|
||||
comment on column adi_knowledge_base_qa_record.question is '用户的原始问题';
|
||||
|
||||
comment on column adi_knowledge_base_qa_record.prompt is '提供给LLM的提示词';
|
||||
|
||||
comment on column adi_knowledge_base_qa_record.prompt_tokens is '提示词消耗的token';
|
||||
|
||||
comment on column adi_knowledge_base_qa_record.answer is '答案';
|
||||
|
||||
comment on column adi_knowledge_base_qa_record.answer_tokens is '答案消耗的token';
|
||||
|
||||
comment on column adi_knowledge_base_qa_record.source_file_ids is '来源文档id,以逗号隔开';
|
||||
|
||||
comment on column adi_knowledge_base_qa_record.user_id is '提问用户id';
|
||||
|
|
Loading…
Reference in New Issue