fix:知识库token计算问题
This commit is contained in:
parent
b1a9cba31e
commit
498e9131b9
18
README.md
18
README.md
|
@ -5,11 +5,11 @@
|
||||||
|
|
||||||
> **该项目如对您有帮助,欢迎点赞**
|
> **该项目如对您有帮助,欢迎点赞**
|
||||||
|
|
||||||
### 体验网址
|
## 体验网址
|
||||||
|
|
||||||
[http://www.aideepin.com](http://www.aideepin.com/)
|
[http://www.aideepin.com](http://www.aideepin.com/)
|
||||||
|
|
||||||
### 功能点
|
## 功能点
|
||||||
|
|
||||||
* 注册&登录
|
* 注册&登录
|
||||||
* 多会话(多角色)
|
* 多会话(多角色)
|
||||||
|
@ -19,14 +19,14 @@
|
||||||
* 基于大模型的知识库(RAG)
|
* 基于大模型的知识库(RAG)
|
||||||
* 多模型随意切换
|
* 多模型随意切换
|
||||||
|
|
||||||
### 接入的模型:
|
## 接入的模型:
|
||||||
|
|
||||||
* ChatGPT 3.5
|
* ChatGPT 3.5
|
||||||
* 通义千问
|
* 通义千问
|
||||||
* 文心一言
|
* 文心一言
|
||||||
* DALL-E 2
|
* DALL-E 2
|
||||||
|
|
||||||
### 技术栈
|
## 技术栈
|
||||||
|
|
||||||
该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web)
|
该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web)
|
||||||
|
|
||||||
|
@ -44,9 +44,9 @@ springboot3.0.5
|
||||||
|
|
||||||
vue3+typescript+pnpm
|
vue3+typescript+pnpm
|
||||||
|
|
||||||
### 如何部署
|
## 如何部署
|
||||||
|
|
||||||
#### 初始化
|
### 初始化
|
||||||
|
|
||||||
初始化数据库
|
初始化数据库
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ update adi_sys_config set value = '{"api_key":"my_dashcope_api_key"}' where name
|
||||||
* redis: application-[dev|prod].xml中的spring.data.redis
|
* redis: application-[dev|prod].xml中的spring.data.redis
|
||||||
* mail: application.xml中的spring.mail
|
* mail: application.xml中的spring.mail
|
||||||
|
|
||||||
#### 编译及运行
|
### 编译及运行
|
||||||
|
|
||||||
* 进入项目
|
* 进入项目
|
||||||
|
|
||||||
|
@ -100,12 +100,12 @@ docker run -d \
|
||||||
aideepin:0.0.1
|
aideepin:0.0.1
|
||||||
```
|
```
|
||||||
|
|
||||||
### 待办:
|
## 待办:
|
||||||
|
|
||||||
增强RAG
|
增强RAG
|
||||||
|
|
||||||
|
|
||||||
### 截图
|
## 截图
|
||||||
|
|
||||||
**AI聊天:**
|
**AI聊天:**
|
||||||
![1691583184761](image/README/1691583184761.png)
|
![1691583184761](image/README/1691583184761.png)
|
||||||
|
|
|
@ -5,12 +5,15 @@ import com.moyz.adi.common.dto.QAReq;
|
||||||
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
|
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
|
||||||
import com.moyz.adi.common.service.KnowledgeBaseQaRecordService;
|
import com.moyz.adi.common.service.KnowledgeBaseQaRecordService;
|
||||||
import com.moyz.adi.common.service.KnowledgeBaseService;
|
import com.moyz.adi.common.service.KnowledgeBaseService;
|
||||||
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import jakarta.validation.constraints.Min;
|
import jakarta.validation.constraints.Min;
|
||||||
import jakarta.validation.constraints.NotNull;
|
import jakarta.validation.constraints.NotNull;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
import org.springframework.validation.annotation.Validated;
|
import org.springframework.validation.annotation.Validated;
|
||||||
import org.springframework.web.bind.annotation.*;
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
@Tag(name = "知识库问答controller")
|
@Tag(name = "知识库问答controller")
|
||||||
@RequestMapping("/knowledge-base/qa/")
|
@RequestMapping("/knowledge-base/qa/")
|
||||||
|
@ -25,7 +28,13 @@ public class KnowledgeBaseQAController {
|
||||||
|
|
||||||
@PostMapping("/ask/{kbUuid}")
|
@PostMapping("/ask/{kbUuid}")
|
||||||
public KnowledgeBaseQaRecord ask(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) {
|
public KnowledgeBaseQaRecord ask(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) {
|
||||||
return knowledgeBaseService.answerAndRecord(kbUuid, req.getQuestion(), req.getModelName());
|
return knowledgeBaseService.ask(kbUuid, req.getQuestion(), req.getModelName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "流式响应")
|
||||||
|
@PostMapping(value = "/process/{kbUuid}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||||
|
public SseEmitter sseAsk(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) {
|
||||||
|
return knowledgeBaseService.sseAsk(kbUuid, req);
|
||||||
}
|
}
|
||||||
|
|
||||||
@GetMapping("/record/search")
|
@GetMapping("/record/search")
|
||||||
|
|
|
@ -31,11 +31,24 @@ public class KnowledgeBaseQaRecord extends BaseEntity {
|
||||||
@TableField("question")
|
@TableField("question")
|
||||||
private String question;
|
private String question;
|
||||||
|
|
||||||
|
@Schema(title = "最终提供给LLM的提示词")
|
||||||
|
@TableField("prompt")
|
||||||
|
private String prompt;
|
||||||
|
|
||||||
|
@Schema(title = "提供给LLM的提示词所消耗的token数量")
|
||||||
|
@TableField("prompt_tokens")
|
||||||
|
private Integer promptTokens;
|
||||||
|
|
||||||
@Schema(title = "答案")
|
@Schema(title = "答案")
|
||||||
@TableField("answer")
|
@TableField("answer")
|
||||||
private String answer;
|
private String answer;
|
||||||
|
|
||||||
|
@Schema(title = "答案消耗的token")
|
||||||
|
@TableField("answer_tokens")
|
||||||
|
private Integer answerTokens;
|
||||||
|
|
||||||
@Schema(title = "提问用户id")
|
@Schema(title = "提问用户id")
|
||||||
@TableField("user_id")
|
@TableField("user_id")
|
||||||
private Long userId;
|
private Long userId;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,13 @@ public class RateLimitHelper {
|
||||||
@Resource
|
@Resource
|
||||||
private StringRedisTemplate stringRedisTemplate;
|
private StringRedisTemplate stringRedisTemplate;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 按固定时间窗口计算请求次数
|
||||||
|
*
|
||||||
|
* @param requestTimesKey redis key
|
||||||
|
* @param rateLimitConfig 请求频率限制配置
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public boolean checkRequestTimes(String requestTimesKey, RequestRateLimit rateLimitConfig) {
|
public boolean checkRequestTimes(String requestTimesKey, RequestRateLimit rateLimitConfig) {
|
||||||
int requestCountInTimeWindow = 0;
|
int requestCountInTimeWindow = 0;
|
||||||
String rateLimitVal = stringRedisTemplate.opsForValue().get(requestTimesKey);
|
String rateLimitVal = stringRedisTemplate.opsForValue().get(requestTimesKey);
|
||||||
|
|
|
@ -0,0 +1,95 @@
|
||||||
|
package com.moyz.adi.common.helper;
|
||||||
|
|
||||||
|
import com.moyz.adi.common.cosntant.RedisKeyConstant;
|
||||||
|
import com.moyz.adi.common.entity.User;
|
||||||
|
import com.moyz.adi.common.interfaces.TriConsumer;
|
||||||
|
import com.moyz.adi.common.util.LocalCache;
|
||||||
|
import com.moyz.adi.common.vo.AnswerMeta;
|
||||||
|
import com.moyz.adi.common.vo.PromptMeta;
|
||||||
|
import com.moyz.adi.common.vo.SseAskParams;
|
||||||
|
import jakarta.annotation.Resource;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.text.MessageFormat;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
public class SSEEmitterHelper {
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private StringRedisTemplate stringRedisTemplate;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private RateLimitHelper rateLimitHelper;
|
||||||
|
|
||||||
|
public void process(User user, SseAskParams sseAskParams, TriConsumer<String, PromptMeta, AnswerMeta> consumer) {
|
||||||
|
SseEmitter sseEmitter = sseAskParams.getSseEmitter();
|
||||||
|
|
||||||
|
//rate limit by system
|
||||||
|
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
|
||||||
|
if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) {
|
||||||
|
sendErrorMsg(sseEmitter, "访问太过频繁");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
//Check: If still waiting response
|
||||||
|
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
|
||||||
|
String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
|
||||||
|
if (StringUtils.isNotBlank(askingVal)) {
|
||||||
|
sendErrorMsg(sseEmitter, "正在回复中...");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
stringRedisTemplate.opsForValue().set(askingKey, "1", 15, TimeUnit.SECONDS);
|
||||||
|
try {
|
||||||
|
sseEmitter.send(SseEmitter.event().name("start"));
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("error", e);
|
||||||
|
sseEmitter.completeWithError(e);
|
||||||
|
stringRedisTemplate.delete(askingKey);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG);
|
||||||
|
|
||||||
|
sseEmitter.onCompletion(() -> {
|
||||||
|
log.info("response complete,uid:{}", user.getId());
|
||||||
|
});
|
||||||
|
sseEmitter.onTimeout(() -> log.warn("sseEmitter timeout,uid:{},on timeout:{}", user.getId(), sseEmitter.getTimeout()));
|
||||||
|
sseEmitter.onError(
|
||||||
|
throwable -> {
|
||||||
|
try {
|
||||||
|
log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable);
|
||||||
|
sseEmitter.send(SseEmitter.event().name("error").data(throwable.getMessage()));
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("error", e);
|
||||||
|
} finally {
|
||||||
|
stringRedisTemplate.delete(askingKey);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
new LLMContext(sseAskParams.getModelName()).getLLMService().sseChat(sseAskParams, (response, promptMeta, answerMeta) -> {
|
||||||
|
try {
|
||||||
|
consumer.accept((String) response, (PromptMeta) promptMeta, (AnswerMeta) answerMeta);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("error:", e);
|
||||||
|
} finally {
|
||||||
|
stringRedisTemplate.delete(askingKey);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public void sendErrorMsg(SseEmitter sseEmitter, String errorMsg) {
|
||||||
|
try {
|
||||||
|
sseEmitter.send(SseEmitter.event().name("error").data(errorMsg));
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
sseEmitter.complete();
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,23 +1,27 @@
|
||||||
package com.moyz.adi.common.interfaces;
|
package com.moyz.adi.common.interfaces;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.util.JSONPObject;
|
||||||
import com.moyz.adi.common.util.JsonUtil;
|
import com.moyz.adi.common.util.JsonUtil;
|
||||||
import com.moyz.adi.common.util.LocalCache;
|
import com.moyz.adi.common.util.LocalCache;
|
||||||
import com.moyz.adi.common.vo.AnswerMeta;
|
import com.moyz.adi.common.vo.AnswerMeta;
|
||||||
import com.moyz.adi.common.vo.ChatMeta;
|
import com.moyz.adi.common.vo.ChatMeta;
|
||||||
import com.moyz.adi.common.vo.QuestionMeta;
|
import com.moyz.adi.common.vo.PromptMeta;
|
||||||
import com.moyz.adi.common.vo.SseAskParams;
|
import com.moyz.adi.common.vo.SseAskParams;
|
||||||
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.data.message.ChatMessage;
|
import dev.langchain4j.data.message.ChatMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.service.AiServices;
|
import dev.langchain4j.service.AiServices;
|
||||||
import dev.langchain4j.service.TokenStream;
|
import dev.langchain4j.service.TokenStream;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.net.Proxy;
|
import java.net.Proxy;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@ -66,11 +70,13 @@ public abstract class AbstractLLMService<T> {
|
||||||
|
|
||||||
protected abstract StreamingChatLanguageModel buildStreamingChatLLM();
|
protected abstract StreamingChatLanguageModel buildStreamingChatLLM();
|
||||||
|
|
||||||
public String chat(ChatMessage chatMessage) {
|
protected abstract String parseError(Object error);
|
||||||
return getChatLLM().generate(chatMessage).content().text();
|
|
||||||
|
public Response<AiMessage> chat(ChatMessage chatMessage) {
|
||||||
|
return getChatLLM().generate(chatMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void sseChat(SseAskParams params, TriConsumer<String, QuestionMeta, AnswerMeta> consumer) {
|
public void sseChat(SseAskParams params, TriConsumer<String, PromptMeta, AnswerMeta> consumer) {
|
||||||
|
|
||||||
//create chat assistant
|
//create chat assistant
|
||||||
AiServices<IChatAssistant> serviceBuilder = AiServices.builder(IChatAssistant.class)
|
AiServices<IChatAssistant> serviceBuilder = AiServices.builder(IChatAssistant.class)
|
||||||
|
@ -98,7 +104,7 @@ public abstract class AbstractLLMService<T> {
|
||||||
.onComplete((response) -> {
|
.onComplete((response) -> {
|
||||||
log.info("返回数据结束了:{}", response);
|
log.info("返回数据结束了:{}", response);
|
||||||
String questionUuid = StringUtils.isNotBlank(params.getRegenerateQuestionUuid()) ? params.getRegenerateQuestionUuid() : UUID.randomUUID().toString().replace("-", "");
|
String questionUuid = StringUtils.isNotBlank(params.getRegenerateQuestionUuid()) ? params.getRegenerateQuestionUuid() : UUID.randomUUID().toString().replace("-", "");
|
||||||
QuestionMeta questionMeta = new QuestionMeta(response.tokenUsage().inputTokenCount(), questionUuid);
|
PromptMeta questionMeta = new PromptMeta(response.tokenUsage().inputTokenCount(), questionUuid);
|
||||||
AnswerMeta answerMeta = new AnswerMeta(response.tokenUsage().outputTokenCount(), UUID.randomUUID().toString().replace("-", ""));
|
AnswerMeta answerMeta = new AnswerMeta(response.tokenUsage().outputTokenCount(), UUID.randomUUID().toString().replace("-", ""));
|
||||||
ChatMeta chatMeta = new ChatMeta(questionMeta, answerMeta);
|
ChatMeta chatMeta = new ChatMeta(questionMeta, answerMeta);
|
||||||
String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", "");
|
String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", "");
|
||||||
|
@ -116,7 +122,11 @@ public abstract class AbstractLLMService<T> {
|
||||||
.onError((error) -> {
|
.onError((error) -> {
|
||||||
log.error("stream error", error);
|
log.error("stream error", error);
|
||||||
try {
|
try {
|
||||||
params.getSseEmitter().send(SseEmitter.event().name("error").data(error.getMessage()));
|
String errorMsg = parseError(error);
|
||||||
|
if(StringUtils.isBlank(errorMsg)){
|
||||||
|
errorMsg = error.getMessage();
|
||||||
|
}
|
||||||
|
params.getSseEmitter().send(SseEmitter.event().name("error").data(errorMsg));
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
log.error("sse error", e);
|
log.error("sse error", e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||||
import com.moyz.adi.common.base.ThreadContext;
|
import com.moyz.adi.common.base.ThreadContext;
|
||||||
import com.moyz.adi.common.cosntant.AdiConstant;
|
import com.moyz.adi.common.cosntant.AdiConstant;
|
||||||
import com.moyz.adi.common.cosntant.RedisKeyConstant;
|
|
||||||
import com.moyz.adi.common.dto.AskReq;
|
import com.moyz.adi.common.dto.AskReq;
|
||||||
import com.moyz.adi.common.entity.Conversation;
|
import com.moyz.adi.common.entity.Conversation;
|
||||||
import com.moyz.adi.common.entity.ConversationMessage;
|
import com.moyz.adi.common.entity.ConversationMessage;
|
||||||
|
@ -13,15 +12,14 @@ import com.moyz.adi.common.entity.UserDayCost;
|
||||||
import com.moyz.adi.common.enums.ChatMessageRoleEnum;
|
import com.moyz.adi.common.enums.ChatMessageRoleEnum;
|
||||||
import com.moyz.adi.common.enums.ErrorEnum;
|
import com.moyz.adi.common.enums.ErrorEnum;
|
||||||
import com.moyz.adi.common.exception.BaseException;
|
import com.moyz.adi.common.exception.BaseException;
|
||||||
import com.moyz.adi.common.helper.LLMContext;
|
|
||||||
import com.moyz.adi.common.helper.QuotaHelper;
|
import com.moyz.adi.common.helper.QuotaHelper;
|
||||||
import com.moyz.adi.common.helper.RateLimitHelper;
|
import com.moyz.adi.common.helper.SSEEmitterHelper;
|
||||||
import com.moyz.adi.common.mapper.ConversationMessageMapper;
|
import com.moyz.adi.common.mapper.ConversationMessageMapper;
|
||||||
import com.moyz.adi.common.util.LocalCache;
|
import com.moyz.adi.common.util.LocalCache;
|
||||||
import com.moyz.adi.common.util.LocalDateTimeUtil;
|
import com.moyz.adi.common.util.LocalDateTimeUtil;
|
||||||
import com.moyz.adi.common.util.UserUtil;
|
import com.moyz.adi.common.util.UserUtil;
|
||||||
import com.moyz.adi.common.vo.AnswerMeta;
|
import com.moyz.adi.common.vo.AnswerMeta;
|
||||||
import com.moyz.adi.common.vo.QuestionMeta;
|
import com.moyz.adi.common.vo.PromptMeta;
|
||||||
import com.moyz.adi.common.vo.SseAskParams;
|
import com.moyz.adi.common.vo.SseAskParams;
|
||||||
import com.theokanning.openai.completion.chat.ChatMessageRole;
|
import com.theokanning.openai.completion.chat.ChatMessageRole;
|
||||||
import dev.langchain4j.data.message.SystemMessage;
|
import dev.langchain4j.data.message.SystemMessage;
|
||||||
|
@ -33,17 +31,13 @@ import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.context.annotation.Lazy;
|
import org.springframework.context.annotation.Lazy;
|
||||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
|
||||||
import org.springframework.scheduling.annotation.Async;
|
import org.springframework.scheduling.annotation.Async;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.text.MessageFormat;
|
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
|
|
||||||
import static com.moyz.adi.common.enums.ErrorEnum.B_MESSAGE_NOT_FOUND;
|
import static com.moyz.adi.common.enums.ErrorEnum.B_MESSAGE_NOT_FOUND;
|
||||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||||
|
@ -56,9 +50,6 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
||||||
@Resource
|
@Resource
|
||||||
private ConversationMessageService _this;
|
private ConversationMessageService _this;
|
||||||
|
|
||||||
@Resource
|
|
||||||
private StringRedisTemplate stringRedisTemplate;
|
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private QuotaHelper quotaHelper;
|
private QuotaHelper quotaHelper;
|
||||||
|
|
||||||
|
@ -70,51 +61,43 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
||||||
private ConversationService conversationService;
|
private ConversationService conversationService;
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private RateLimitHelper rateLimitHelper;
|
private SSEEmitterHelper sseEmitterHelper;
|
||||||
|
|
||||||
|
|
||||||
public SseEmitter sseAsk(AskReq askReq) {
|
public SseEmitter sseAsk(AskReq askReq) {
|
||||||
SseEmitter sseEmitter = new SseEmitter();
|
SseEmitter sseEmitter = new SseEmitter();
|
||||||
User user = ThreadContext.getCurrentUser();
|
_this.asyncCheckAndPushToClient(sseEmitter, ThreadContext.getCurrentUser(), askReq);
|
||||||
_this.asyncCheckAndPushToClient(sseEmitter, user, askReq);
|
|
||||||
return sseEmitter;
|
return sseEmitter;
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean check(SseEmitter sseEmitter, User user, AskReq askReq) {
|
private boolean check(SseEmitter sseEmitter, User user, AskReq askReq) {
|
||||||
try {
|
try {
|
||||||
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
|
|
||||||
String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
|
|
||||||
//check 1: still waiting response
|
|
||||||
if (StringUtils.isNotBlank(askingVal)) {
|
|
||||||
sendErrorMsg(sseEmitter, "正在回复中...");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
//check 2: the conversation has been deleted
|
//check 1: the conversation has been deleted
|
||||||
Conversation delConv = conversationService.lambdaQuery()
|
Conversation delConv = conversationService.lambdaQuery()
|
||||||
.eq(Conversation::getUuid, askReq.getConversationUuid())
|
.eq(Conversation::getUuid, askReq.getConversationUuid())
|
||||||
.eq(Conversation::getIsDeleted, true)
|
.eq(Conversation::getIsDeleted, true)
|
||||||
.one();
|
.one();
|
||||||
if (null != delConv) {
|
if (null != delConv) {
|
||||||
sendErrorMsg(sseEmitter, "该对话已经删除");
|
sseEmitterHelper.sendErrorMsg(sseEmitter, "该对话已经删除");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
//check 3: conversation quota
|
//check 2: conversation quota
|
||||||
Long convsCount = conversationService.lambdaQuery()
|
Long convsCount = conversationService.lambdaQuery()
|
||||||
.eq(Conversation::getUserId, user.getId())
|
.eq(Conversation::getUserId, user.getId())
|
||||||
.eq(Conversation::getIsDeleted, false)
|
.eq(Conversation::getIsDeleted, false)
|
||||||
.count();
|
.count();
|
||||||
long convsMax = Integer.parseInt(LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.CONVERSATION_MAX_NUM));
|
long convsMax = Integer.parseInt(LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.CONVERSATION_MAX_NUM));
|
||||||
if (convsCount >= convsMax) {
|
if (convsCount >= convsMax) {
|
||||||
sendErrorMsg(sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax);
|
sseEmitterHelper.sendErrorMsg(sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
//check 4: current user's quota
|
//check 3: current user's quota
|
||||||
ErrorEnum errorMsg = quotaHelper.checkTextQuota(user);
|
ErrorEnum errorMsg = quotaHelper.checkTextQuota(user);
|
||||||
if (null != errorMsg) {
|
if (null != errorMsg) {
|
||||||
sendErrorMsg(sseEmitter, errorMsg.getInfo());
|
sseEmitterHelper.sendErrorMsg(sseEmitter, errorMsg.getInfo());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
@ -125,60 +108,16 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void sendErrorMsg(SseEmitter sseEmitter, String errorMsg) {
|
|
||||||
try {
|
|
||||||
sseEmitter.send(SseEmitter.event().name("error").data(errorMsg));
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
sseEmitter.complete();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Async
|
@Async
|
||||||
public void asyncCheckAndPushToClient(SseEmitter sseEmitter, User user, AskReq askReq) {
|
public void asyncCheckAndPushToClient(SseEmitter sseEmitter, User user, AskReq askReq) {
|
||||||
log.info("asyncCheckAndPushToClient,userId:{}", user.getId());
|
log.info("asyncCheckAndPushToClient,userId:{}", user.getId());
|
||||||
//rate limit by system
|
|
||||||
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
|
|
||||||
if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) {
|
|
||||||
sendErrorMsg(sseEmitter, "访问太过频繁");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//check business rules
|
//check business rules
|
||||||
if (!check(sseEmitter, user, askReq)) {
|
if (!check(sseEmitter, user, askReq)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
|
|
||||||
stringRedisTemplate.opsForValue().set(askingKey, "1", 15, TimeUnit.SECONDS);
|
|
||||||
try {
|
|
||||||
sseEmitter.send(SseEmitter.event().name("start"));
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.error("error", e);
|
|
||||||
sseEmitter.completeWithError(e);
|
|
||||||
stringRedisTemplate.delete(askingKey);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG);
|
|
||||||
|
|
||||||
sseEmitter.onCompletion(() -> {
|
|
||||||
log.info("response complete,uid:{}", user.getId());
|
|
||||||
});
|
|
||||||
sseEmitter.onTimeout(() -> log.warn("sseEmitter timeout,uid:{},on timeout:{}", user.getId(), sseEmitter.getTimeout()));
|
|
||||||
sseEmitter.onError(
|
|
||||||
throwable -> {
|
|
||||||
try {
|
|
||||||
log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable);
|
|
||||||
sseEmitter.send(SseEmitter.event().name("error").data(throwable.getMessage()));
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.error("error", e);
|
|
||||||
} finally {
|
|
||||||
stringRedisTemplate.delete(askingKey);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
SseAskParams sseAskParams = new SseAskParams();
|
SseAskParams sseAskParams = new SseAskParams();
|
||||||
|
sseAskParams.setModelName(askReq.getModelName());
|
||||||
String prompt = askReq.getPrompt();
|
String prompt = askReq.getPrompt();
|
||||||
if (StringUtils.isNotBlank(askReq.getRegenerateQuestionUuid())) {
|
if (StringUtils.isNotBlank(askReq.getRegenerateQuestionUuid())) {
|
||||||
prompt = getPromptMsgByQuestionUuid(askReq.getRegenerateQuestionUuid()).getRemark();
|
prompt = getPromptMsgByQuestionUuid(askReq.getRegenerateQuestionUuid()).getRemark();
|
||||||
|
@ -222,14 +161,8 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
new LLMContext(askReq.getModelName()).getLLMService().sseChat(sseAskParams, (response, questionMeta, answerMeta) -> {
|
sseEmitterHelper.process(user, sseAskParams, (response, questionMeta, answerMeta) -> {
|
||||||
try {
|
_this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta);
|
||||||
_this.saveAfterAiResponse(user, askReq, (String) response, (QuestionMeta) questionMeta, (AnswerMeta) answerMeta);
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("error:", e);
|
|
||||||
} finally {
|
|
||||||
stringRedisTemplate.delete(askingKey);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -245,7 +178,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
||||||
}
|
}
|
||||||
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void saveAfterAiResponse(User user, AskReq askReq, String response, QuestionMeta questionMeta, AnswerMeta answerMeta) {
|
public void saveAfterAiResponse(User user, AskReq askReq, String response, PromptMeta questionMeta, AnswerMeta answerMeta) {
|
||||||
|
|
||||||
int secretKeyType = StringUtils.isNotBlank(user.getSecretKey()) ? AdiConstant.SECRET_KEY_TYPE_CUSTOM : AdiConstant.SECRET_KEY_TYPE_SYSTEM;
|
int secretKeyType = StringUtils.isNotBlank(user.getSecretKey()) ? AdiConstant.SECRET_KEY_TYPE_CUSTOM : AdiConstant.SECRET_KEY_TYPE_SYSTEM;
|
||||||
Conversation conversation;
|
Conversation conversation;
|
||||||
|
@ -294,7 +227,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void calcTodayCost(User user, Conversation conversation, QuestionMeta questionMeta, AnswerMeta answerMeta) {
|
private void calcTodayCost(User user, Conversation conversation, PromptMeta questionMeta, AnswerMeta answerMeta) {
|
||||||
|
|
||||||
int todayTokenCost = questionMeta.getTokens() + answerMeta.getTokens();
|
int todayTokenCost = questionMeta.getTokens() + answerMeta.getTokens();
|
||||||
try {
|
try {
|
||||||
|
@ -304,20 +237,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
||||||
.set(Conversation::getTokens, conversation.getTokens() + todayTokenCost)
|
.set(Conversation::getTokens, conversation.getTokens() + todayTokenCost)
|
||||||
.update();
|
.update();
|
||||||
|
|
||||||
UserDayCost userDayCost = userDayCostService.getTodayCost(user);
|
userDayCostService.appendCostToUser(user, todayTokenCost);
|
||||||
UserDayCost saveOrUpdateInst = new UserDayCost();
|
|
||||||
if (null == userDayCost) {
|
|
||||||
saveOrUpdateInst.setUserId(user.getId());
|
|
||||||
saveOrUpdateInst.setDay(LocalDateTimeUtil.getToday());
|
|
||||||
saveOrUpdateInst.setTokens(todayTokenCost);
|
|
||||||
saveOrUpdateInst.setRequests(1);
|
|
||||||
saveOrUpdateInst.setSecretKeyType(UserUtil.getSecretType(user));
|
|
||||||
} else {
|
|
||||||
saveOrUpdateInst.setId(userDayCost.getId());
|
|
||||||
saveOrUpdateInst.setTokens(userDayCost.getTokens() + todayTokenCost);
|
|
||||||
saveOrUpdateInst.setRequests(userDayCost.getRequests() + 1);
|
|
||||||
}
|
|
||||||
userDayCostService.saveOrUpdate(saveOrUpdateInst);
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("calcTodayCost error", e);
|
log.error("calcTodayCost error", e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,6 +41,11 @@ public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> {
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * DashScope performs no provider-specific error parsing.
 *
 * @param error the error object reported by the streaming call; ignored here
 * @return always {@code null}
 */
@Override
|
||||||
|
protected String parseError(Object error) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected ChatLanguageModel buildChatLLM() {
|
protected ChatLanguageModel buildChatLLM() {
|
||||||
if (StringUtils.isBlank(setting.getApiKey())) {
|
if (StringUtils.isBlank(setting.getApiKey())) {
|
||||||
|
|
|
@ -5,13 +5,17 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||||
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
|
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
|
||||||
import com.moyz.adi.common.base.ThreadContext;
|
import com.moyz.adi.common.base.ThreadContext;
|
||||||
|
import com.moyz.adi.common.entity.KnowledgeBase;
|
||||||
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
|
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
|
||||||
|
import com.moyz.adi.common.entity.User;
|
||||||
import com.moyz.adi.common.exception.BaseException;
|
import com.moyz.adi.common.exception.BaseException;
|
||||||
import com.moyz.adi.common.mapper.KnowledgeBaseQaRecordMapper;
|
import com.moyz.adi.common.mapper.KnowledgeBaseQaRecordMapper;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND;
|
import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@ -21,6 +25,7 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
|
||||||
public Page<KnowledgeBaseQaRecord> search(String kbUuid, String keyword, Integer currentPage, Integer pageSize) {
|
public Page<KnowledgeBaseQaRecord> search(String kbUuid, String keyword, Integer currentPage, Integer pageSize) {
|
||||||
LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>();
|
LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>();
|
||||||
wrapper.eq(KnowledgeBaseQaRecord::getKbUuid, kbUuid);
|
wrapper.eq(KnowledgeBaseQaRecord::getKbUuid, kbUuid);
|
||||||
|
wrapper.eq(KnowledgeBaseQaRecord::getIsDeleted, false);
|
||||||
if (!ThreadContext.getCurrentUser().getIsAdmin()) {
|
if (!ThreadContext.getCurrentUser().getIsAdmin()) {
|
||||||
wrapper.eq(KnowledgeBaseQaRecord::getUserId, ThreadContext.getCurrentUserId());
|
wrapper.eq(KnowledgeBaseQaRecord::getUserId, ThreadContext.getCurrentUserId());
|
||||||
}
|
}
|
||||||
|
@ -31,6 +36,36 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
|
||||||
return baseMapper.selectPage(new Page<>(currentPage, pageSize), wrapper);
|
return baseMapper.selectPage(new Page<>(currentPage, pageSize), wrapper);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建新的QA记录
|
||||||
|
*
|
||||||
|
* @param knowledgeBase 所属的知识库
|
||||||
|
* @param question 用户的原始问题
|
||||||
|
* @param prompt 根据{question}生成的最终提示词
|
||||||
|
* @param promptTokens 提示词消耗的token
|
||||||
|
* @param answer 答案
|
||||||
|
* @param answerTokens 答案消耗的token
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public KnowledgeBaseQaRecord createNewRecord(User user, KnowledgeBase knowledgeBase, String question, String prompt, int promptTokens, String answer, int answerTokens) {
|
||||||
|
String uuid = UUID.randomUUID().toString().replace("-", "");
|
||||||
|
KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord();
|
||||||
|
newObj.setKbId(knowledgeBase.getId());
|
||||||
|
newObj.setKbUuid((knowledgeBase.getUuid()));
|
||||||
|
newObj.setUuid(uuid);
|
||||||
|
newObj.setUserId(user.getId());
|
||||||
|
newObj.setQuestion(question);
|
||||||
|
newObj.setPrompt(prompt);
|
||||||
|
newObj.setPromptTokens(promptTokens);
|
||||||
|
newObj.setAnswer(answer);
|
||||||
|
newObj.setAnswerTokens(answerTokens);
|
||||||
|
baseMapper.insert(newObj);
|
||||||
|
|
||||||
|
LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>();
|
||||||
|
wrapper.eq(KnowledgeBaseQaRecord::getUuid, uuid);
|
||||||
|
return baseMapper.selectOne(wrapper);
|
||||||
|
}
|
||||||
|
|
||||||
public boolean softDelele(String uuid) {
|
public boolean softDelele(String uuid) {
|
||||||
if (ThreadContext.getCurrentUser().getIsAdmin()) {
|
if (ThreadContext.getCurrentUser().getIsAdmin()) {
|
||||||
return ChainWrappers.lambdaUpdateChain(baseMapper)
|
return ChainWrappers.lambdaUpdateChain(baseMapper)
|
||||||
|
|
|
@ -7,29 +7,35 @@ import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
|
||||||
import com.moyz.adi.common.base.ThreadContext;
|
import com.moyz.adi.common.base.ThreadContext;
|
||||||
import com.moyz.adi.common.cosntant.RedisKeyConstant;
|
import com.moyz.adi.common.cosntant.RedisKeyConstant;
|
||||||
import com.moyz.adi.common.dto.KbEditReq;
|
import com.moyz.adi.common.dto.KbEditReq;
|
||||||
|
import com.moyz.adi.common.dto.QAReq;
|
||||||
import com.moyz.adi.common.entity.*;
|
import com.moyz.adi.common.entity.*;
|
||||||
import com.moyz.adi.common.exception.BaseException;
|
import com.moyz.adi.common.exception.BaseException;
|
||||||
|
import com.moyz.adi.common.helper.SSEEmitterHelper;
|
||||||
import com.moyz.adi.common.mapper.KnowledgeBaseMapper;
|
import com.moyz.adi.common.mapper.KnowledgeBaseMapper;
|
||||||
import com.moyz.adi.common.util.BizPager;
|
import com.moyz.adi.common.util.BizPager;
|
||||||
import com.moyz.adi.common.util.LocalDateTimeUtil;
|
import com.moyz.adi.common.util.LocalDateTimeUtil;
|
||||||
|
import com.moyz.adi.common.vo.SseAskParams;
|
||||||
import dev.langchain4j.data.document.Document;
|
import dev.langchain4j.data.document.Document;
|
||||||
import dev.langchain4j.data.document.parser.TextDocumentParser;
|
import dev.langchain4j.data.document.parser.TextDocumentParser;
|
||||||
import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentParser;
|
import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentParser;
|
||||||
import dev.langchain4j.data.document.parser.apache.poi.ApachePoiDocumentParser;
|
import dev.langchain4j.data.document.parser.apache.poi.ApachePoiDocumentParser;
|
||||||
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.context.annotation.Lazy;
|
||||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import org.springframework.scheduling.annotation.Async;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.web.multipart.MultipartFile;
|
import org.springframework.web.multipart.MultipartFile;
|
||||||
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
import java.text.MessageFormat;
|
import java.text.MessageFormat;
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.UUID;
|
|
||||||
|
|
||||||
import static com.moyz.adi.common.cosntant.AdiConstant.POI_DOC_TYPES;
|
import static com.moyz.adi.common.cosntant.AdiConstant.POI_DOC_TYPES;
|
||||||
import static com.moyz.adi.common.enums.ErrorEnum.*;
|
import static com.moyz.adi.common.enums.ErrorEnum.*;
|
||||||
|
@ -39,6 +45,10 @@ import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.load
|
||||||
@Service
|
@Service
|
||||||
public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, KnowledgeBase> {
|
public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, KnowledgeBase> {
|
||||||
|
|
||||||
|
@Lazy
|
||||||
|
@Resource
|
||||||
|
private KnowledgeBaseService _this;
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private StringRedisTemplate stringRedisTemplate;
|
private StringRedisTemplate stringRedisTemplate;
|
||||||
|
|
||||||
|
@ -54,6 +64,12 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
||||||
@Resource
|
@Resource
|
||||||
private FileService fileService;
|
private FileService fileService;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private SSEEmitterHelper sseEmitterHelper;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private UserDayCostService userDayCostService;
|
||||||
|
|
||||||
public KnowledgeBase saveOrUpdate(KbEditReq kbEditReq) {
|
public KnowledgeBase saveOrUpdate(KbEditReq kbEditReq) {
|
||||||
String uuid = kbEditReq.getUuid();
|
String uuid = kbEditReq.getUuid();
|
||||||
KnowledgeBase knowledgeBase = new KnowledgeBase();
|
KnowledgeBase knowledgeBase = new KnowledgeBase();
|
||||||
|
@ -184,7 +200,6 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public boolean softDelete(String uuid) {
|
public boolean softDelete(String uuid) {
|
||||||
checkPrivilege(null, uuid);
|
checkPrivilege(null, uuid);
|
||||||
return ChainWrappers.lambdaUpdateChain(baseMapper)
|
return ChainWrappers.lambdaUpdateChain(baseMapper)
|
||||||
|
@ -193,8 +208,29 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
||||||
.update();
|
.update();
|
||||||
}
|
}
|
||||||
|
|
||||||
public KnowledgeBaseQaRecord answerAndRecord(String kbUuid, String question, String modelName) {
|
public KnowledgeBaseQaRecord ask(String kbUuid, String question, String modelName) {
|
||||||
|
checkRequestTimesOrThrow();
|
||||||
|
KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
|
||||||
|
Pair<String, Response<AiMessage>> responsePair = ragService.retrieveAndAsk(kbUuid, question, modelName);
|
||||||
|
|
||||||
|
Response<AiMessage> ar = responsePair.getRight();
|
||||||
|
int inputTokenCount = ar.tokenUsage().inputTokenCount();
|
||||||
|
int outputTokenCount = ar.tokenUsage().outputTokenCount();
|
||||||
|
userDayCostService.appendCostToUser(ThreadContext.getCurrentUser(), inputTokenCount + outputTokenCount);
|
||||||
|
return knowledgeBaseQaRecordService.createNewRecord(ThreadContext.getCurrentUser(), knowledgeBase, question, responsePair.getLeft(), inputTokenCount, ar.content().text(), outputTokenCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SseEmitter sseAsk(String kbUuid, QAReq req) {
|
||||||
|
checkRequestTimesOrThrow();
|
||||||
|
SseEmitter sseEmitter = new SseEmitter();
|
||||||
|
_this.retrieveAndPushToLLM(ThreadContext.getCurrentUser(), sseEmitter, kbUuid, req);
|
||||||
|
return sseEmitter;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识库问答限额判断
|
||||||
|
*/
|
||||||
|
private void checkRequestTimesOrThrow() {
|
||||||
String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd"));
|
String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd"));
|
||||||
String askTimes = stringRedisTemplate.opsForValue().get(key);
|
String askTimes = stringRedisTemplate.opsForValue().get(key);
|
||||||
String askQuota = SysConfigService.getByKey("quota_by_qa_ask_daily");
|
String askQuota = SysConfigService.getByKey("quota_by_qa_ask_daily");
|
||||||
|
@ -202,19 +238,24 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
||||||
throw new BaseException(A_QA_ASK_LIMIT);
|
throw new BaseException(A_QA_ASK_LIMIT);
|
||||||
}
|
}
|
||||||
stringRedisTemplate.opsForValue().increment(key);
|
stringRedisTemplate.opsForValue().increment(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Async
|
||||||
|
public void retrieveAndPushToLLM(User user, SseEmitter sseEmitter, String kbUuid, QAReq req) {
|
||||||
|
log.info("retrieveAndPushToLLM,kbUuid:{},userId:{}", kbUuid, user.getId());
|
||||||
KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
|
KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
|
||||||
String answer = ragService.findAnswer(kbUuid, question, modelName);
|
|
||||||
String uuid = UUID.randomUUID().toString().replace("-", "");
|
String prompt = ragService.retrieveAndCreatePrompt(kbUuid, req.getQuestion()).text();
|
||||||
KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord();
|
SseAskParams sseAskParams = new SseAskParams();
|
||||||
newObj.setKbId(knowledgeBase.getId());
|
sseAskParams.setSystemMessage(StringUtils.EMPTY);
|
||||||
newObj.setKbUuid((knowledgeBase.getUuid()));
|
sseAskParams.setSseEmitter(sseEmitter);
|
||||||
newObj.setUuid(uuid);
|
sseAskParams.setUserMessage(prompt);
|
||||||
newObj.setUserId(ThreadContext.getCurrentUserId());
|
sseAskParams.setModelName(req.getModelName());
|
||||||
newObj.setQuestion(question);
|
sseEmitterHelper.process(user, sseAskParams, (response, promptMeta, answerMeta) -> {
|
||||||
newObj.setAnswer(answer);
|
knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), prompt, promptMeta.getTokens(), response, answerMeta.getTokens());
|
||||||
knowledgeBaseQaRecordService.save(newObj);
|
userDayCostService.appendCostToUser(user, promptMeta.getTokens() + answerMeta.getTokens());
|
||||||
return knowledgeBaseQaRecordService.lambdaQuery().eq(KnowledgeBaseQaRecord::getUuid, uuid).one();
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
public KnowledgeBase getOrThrow(String kbUuid) {
|
public KnowledgeBase getOrThrow(String kbUuid) {
|
||||||
|
@ -248,4 +289,5 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
||||||
throw new BaseException(A_USER_NOT_AUTH);
|
throw new BaseException(A_USER_NOT_AUTH);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,10 @@ import com.moyz.adi.common.cosntant.AdiConstant;
|
||||||
import com.moyz.adi.common.enums.ErrorEnum;
|
import com.moyz.adi.common.enums.ErrorEnum;
|
||||||
import com.moyz.adi.common.exception.BaseException;
|
import com.moyz.adi.common.exception.BaseException;
|
||||||
import com.moyz.adi.common.interfaces.AbstractLLMService;
|
import com.moyz.adi.common.interfaces.AbstractLLMService;
|
||||||
|
import com.moyz.adi.common.util.JsonUtil;
|
||||||
import com.moyz.adi.common.vo.OpenAiSetting;
|
import com.moyz.adi.common.vo.OpenAiSetting;
|
||||||
|
import com.theokanning.openai.OpenAiError;
|
||||||
|
import dev.ai4j.openai4j.OpenAiHttpException;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||||
|
@ -12,6 +15,7 @@ import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||||
import lombok.experimental.Accessors;
|
import lombok.experimental.Accessors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.logging.log4j.util.Strings;
|
||||||
|
|
||||||
import java.net.Proxy;
|
import java.net.Proxy;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
@ -49,6 +53,16 @@ public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
|
||||||
return builder.build();
|
return builder.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String parseError(Object error) {
|
||||||
|
if(error instanceof OpenAiHttpException){
|
||||||
|
OpenAiHttpException openAiHttpException = (OpenAiHttpException)error;
|
||||||
|
OpenAiError openAiError = JsonUtil.fromJson(openAiHttpException.getMessage(), OpenAiError.class);
|
||||||
|
return openAiError.getError().getMessage();
|
||||||
|
}
|
||||||
|
return Strings.EMPTY;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected ChatLanguageModel buildChatLLM() {
|
protected ChatLanguageModel buildChatLLM() {
|
||||||
if (StringUtils.isBlank(setting.getSecretKey())) {
|
if (StringUtils.isBlank(setting.getSecretKey())) {
|
||||||
|
|
|
@ -49,4 +49,9 @@ public class QianFanLLMService extends AbstractLLMService<QianFanSetting> {
|
||||||
.secretKey(setting.getSecretKey())
|
.secretKey(setting.getSecretKey())
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String parseError(Object error) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,22 +1,30 @@
|
||||||
package com.moyz.adi.common.service;
|
package com.moyz.adi.common.service;
|
||||||
|
|
||||||
import com.moyz.adi.common.helper.LLMContext;
|
import com.moyz.adi.common.helper.LLMContext;
|
||||||
|
import com.moyz.adi.common.interfaces.TriConsumer;
|
||||||
import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore;
|
import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore;
|
||||||
|
import com.moyz.adi.common.vo.AnswerMeta;
|
||||||
|
import com.moyz.adi.common.vo.PromptMeta;
|
||||||
import dev.langchain4j.data.document.Document;
|
import dev.langchain4j.data.document.Document;
|
||||||
import dev.langchain4j.data.document.DocumentSplitter;
|
import dev.langchain4j.data.document.DocumentSplitter;
|
||||||
import dev.langchain4j.data.document.splitter.DocumentSplitters;
|
import dev.langchain4j.data.document.splitter.DocumentSplitters;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
|
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.apache.commons.lang3.tuple.Triple;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
@ -104,16 +112,7 @@ public class RAGService {
|
||||||
getEmbeddingStoreIngestor().ingest(document);
|
getEmbeddingStoreIngestor().ingest(document);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
public Prompt retrieveAndCreatePrompt(String kbUuid, String question) {
|
||||||
* 召回并搜索
|
|
||||||
*
|
|
||||||
* @param kbUuid 知识库uuid
|
|
||||||
* @param question 用户的问题
|
|
||||||
* @param modelName LLM model name
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public String findAnswer(String kbUuid, String question, String modelName) {
|
|
||||||
|
|
||||||
// Embed the question
|
// Embed the question
|
||||||
Embedding questionEmbedding = embeddingModel.embed(question).content();
|
Embedding questionEmbedding = embeddingModel.embed(question).content();
|
||||||
|
|
||||||
|
@ -129,10 +128,25 @@ public class RAGService {
|
||||||
.collect(joining("\n\n"));
|
.collect(joining("\n\n"));
|
||||||
|
|
||||||
if (StringUtils.isBlank(information)) {
|
if (StringUtils.isBlank(information)) {
|
||||||
return StringUtils.EMPTY;
|
return null;
|
||||||
|
}
|
||||||
|
return promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information)));
|
||||||
}
|
}
|
||||||
Prompt prompt = promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information)));
|
|
||||||
|
|
||||||
return new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage());
|
/**
|
||||||
|
* 召回并提问
|
||||||
|
*
|
||||||
|
* @param kbUuid 知识库uuid
|
||||||
|
* @param question 用户的问题
|
||||||
|
* @param modelName LLM model name
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public Pair<String, Response<AiMessage>> retrieveAndAsk(String kbUuid, String question, String modelName) {
|
||||||
|
Prompt prompt = retrieveAndCreatePrompt(kbUuid, question);
|
||||||
|
if (null == prompt) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
Response<AiMessage> response = new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage());
|
||||||
|
return new ImmutablePair<>(prompt.text(), response);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,23 @@ import java.util.List;
|
||||||
@Service
|
@Service
|
||||||
public class UserDayCostService extends ServiceImpl<UserDayCostMapper, UserDayCost> {
|
public class UserDayCostService extends ServiceImpl<UserDayCostMapper, UserDayCost> {
|
||||||
|
|
||||||
|
public void appendCostToUser(User user, int tokens) {
|
||||||
|
UserDayCost userDayCost = getTodayCost(user);
|
||||||
|
UserDayCost saveOrUpdateInst = new UserDayCost();
|
||||||
|
if (null == userDayCost) {
|
||||||
|
saveOrUpdateInst.setUserId(user.getId());
|
||||||
|
saveOrUpdateInst.setDay(LocalDateTimeUtil.getToday());
|
||||||
|
saveOrUpdateInst.setTokens(tokens);
|
||||||
|
saveOrUpdateInst.setRequests(1);
|
||||||
|
saveOrUpdateInst.setSecretKeyType(UserUtil.getSecretType(user));
|
||||||
|
} else {
|
||||||
|
saveOrUpdateInst.setId(userDayCost.getId());
|
||||||
|
saveOrUpdateInst.setTokens(userDayCost.getTokens() + tokens);
|
||||||
|
saveOrUpdateInst.setRequests(userDayCost.getRequests() + 1);
|
||||||
|
}
|
||||||
|
saveOrUpdate(saveOrUpdateInst);
|
||||||
|
}
|
||||||
|
|
||||||
public CostStat costStatByUser(long userId) {
|
public CostStat costStatByUser(long userId) {
|
||||||
CostStat result = new CostStat();
|
CostStat result = new CostStat();
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,6 @@ import lombok.Data;
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class ChatMeta {
|
public class ChatMeta {
|
||||||
private QuestionMeta question;
|
private PromptMeta question;
|
||||||
private AnswerMeta answer;
|
private AnswerMeta answer;
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@ import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class QuestionMeta {
|
public class PromptMeta {
|
||||||
private Integer tokens;
|
private Integer tokens;
|
||||||
private String uuid;
|
private String uuid;
|
||||||
}
|
}
|
|
@ -21,4 +21,5 @@ public class SseAskParams {
|
||||||
|
|
||||||
private SseEmitter sseEmitter;
|
private SseEmitter sseEmitter;
|
||||||
|
|
||||||
|
private String modelName;
|
||||||
}
|
}
|
||||||
|
|
|
@ -497,7 +497,10 @@ create table adi_knowledge_base_qa_record
|
||||||
kb_id bigint DEFAULT '0'::bigint NOT NULL,
|
kb_id bigint DEFAULT '0'::bigint NOT NULL,
|
||||||
kb_uuid varchar(32) default ''::character varying not null,
|
kb_uuid varchar(32) default ''::character varying not null,
|
||||||
question varchar(1000) default ''::character varying not null,
|
question varchar(1000) default ''::character varying not null,
|
||||||
|
prompt text default ''::character varying not null,
|
||||||
|
prompt_tokens integer DEFAULT 0 NOT NULL,
|
||||||
answer text default ''::character varying not null,
|
answer text default ''::character varying not null,
|
||||||
|
answer_tokens integer DEFAULT 0 NOT NULL,
|
||||||
source_file_ids varchar(500) default ''::character varying not null,
|
source_file_ids varchar(500) default ''::character varying not null,
|
||||||
user_id bigint default '0' NOT NULL,
|
user_id bigint default '0' NOT NULL,
|
||||||
create_time timestamp default CURRENT_TIMESTAMP not null,
|
create_time timestamp default CURRENT_TIMESTAMP not null,
|
||||||
|
@ -511,10 +514,16 @@ comment on column adi_knowledge_base_qa_record.kb_id is '所属知识库id';
|
||||||
|
|
||||||
comment on column adi_knowledge_base_qa_record.kb_uuid is '所属知识库uuid';
|
comment on column adi_knowledge_base_qa_record.kb_uuid is '所属知识库uuid';
|
||||||
|
|
||||||
comment on column adi_knowledge_base_qa_record.question is '问题';
|
comment on column adi_knowledge_base_qa_record.question is '用户的原始问题';
|
||||||
|
|
||||||
|
comment on column adi_knowledge_base_qa_record.prompt is '提供给LLM的提示词';
|
||||||
|
|
||||||
|
comment on column adi_knowledge_base_qa_record.prompt_tokens is '提示词消耗的token';
|
||||||
|
|
||||||
comment on column adi_knowledge_base_qa_record.answer is '答案';
|
comment on column adi_knowledge_base_qa_record.answer is '答案';
|
||||||
|
|
||||||
|
comment on column adi_knowledge_base_qa_record.answer_tokens is '答案消耗的token';
|
||||||
|
|
||||||
comment on column adi_knowledge_base_qa_record.source_file_ids is '来源文档id,以逗号隔开';
|
comment on column adi_knowledge_base_qa_record.source_file_ids is '来源文档id,以逗号隔开';
|
||||||
|
|
||||||
comment on column adi_knowledge_base_qa_record.user_id is '提问用户id';
|
comment on column adi_knowledge_base_qa_record.user_id is '提问用户id';
|
||||||
|
|
Loading…
Reference in New Issue