模型配置放到表中

This commit is contained in:
moyangzhan 2024-05-05 18:40:32 +08:00
parent b9f2d78a4f
commit 2341c92cf8
26 changed files with 383 additions and 149 deletions

View File

@ -70,25 +70,39 @@ vue3+typescript+pnpm
openai的secretKey
```plaintext
update adi_sys_config set value = '{"secret_key":"my_openai_secret_key","models":["gpt-3.5-turbo"]}' where name = 'openai_setting';
update adi_sys_config set value = '{"secret_key":"my_openai_secret_key"}' where name = 'openai_setting';
```
灵积大模型平台的apiKey
```plaintext
update adi_sys_config set value = '{"api_key":"my_dashcope_api_key","models":["my model name,eg:qwen-max"]}' where name = 'dashscope_setting';
update adi_sys_config set value = '{"api_key":"my_dashscope_api_key"}' where name = 'dashscope_setting';
```
千帆大模型平台的配置
```plaintext
update adi_sys_config set value = '{"api_key":"my_qianfan_api_key","secret_key":"my_qianfan_secret_key","models":["my model name,eg:ERNIE-Bot"]}' where name = 'qianfan_setting';
update adi_sys_config set value = '{"api_key":"my_qianfan_api_key","secret_key":"my_qianfan_secret_key"}' where name = 'qianfan_setting';
```
ollama的配置
```
update adi_sys_config set value = '{"base_url":"my_ollama_base_url","models":["my model name,eg:tinydolphin"]}' where name = 'ollama_setting';
update adi_sys_config set value = '{"base_url":"my_ollama_base_url"}' where name = 'ollama_setting';
```
* 启用模型或新增模型
```
-- Enable model
update adi_ai_model set is_enable = true where name = 'gpt-3.5-turbo';
update adi_ai_model set is_enable = true where name = 'dall-e-2';
update adi_ai_model set is_enable = true where name = 'qwen-turbo';
update adi_ai_model set is_enable = true where name = 'ernie-3.5-8k-0205';
update adi_ai_model set is_enable = true where name = 'tinydolphin';
-- Add new model
INSERT INTO adi_ai_model (name, type, platform, is_enable) VALUES ('vicuna', 'text', 'ollama', true);
```
* 填充搜索引擎的配置

View File

@ -1,6 +1,7 @@
package com.moyz.adi.chat.controller;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.moyz.adi.common.dto.KbQaRecordDto;
import com.moyz.adi.common.dto.QAReq;
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
import com.moyz.adi.common.service.KnowledgeBaseQaRecordService;
@ -38,7 +39,7 @@ public class KnowledgeBaseQAController {
}
@GetMapping("/record/search")
public Page<KnowledgeBaseQaRecord> list(String kbUuid, String keyword, @NotNull @Min(1) Integer currentPage, @NotNull @Min(10) Integer pageSize) {
public Page<KbQaRecordDto> list(String kbUuid, String keyword, @NotNull @Min(1) Integer currentPage, @NotNull @Min(10) Integer pageSize) {
return knowledgeBaseQaRecordService.search(kbUuid, keyword, currentPage, pageSize);
}

View File

@ -102,6 +102,21 @@ public class AdiConstant {
public static final String[] POI_DOC_TYPES = {"doc", "docx", "ppt", "pptx", "xls", "xlsx"};
public static class ModelPlatform {
public static final String OPENAI = "openai";
public static final String DASHSCOPE = "dashscope";
public static final String QIANFAN = "qianfan";
public static final String OLLAMA = "ollama";
}
public static class ModelType {
public static final String TEXT = "text";
public static final String IMAGE = "image";
public static final String EMBEDDING = "embedding";
public static final String RERANK = "rerank";
}
public static class SearchEngineName {
public static final String GOOGLE = "google";
public static final String BING = "bing";

View File

@ -0,0 +1,44 @@
package com.moyz.adi.common.dto;
import com.baomidou.mybatisplus.annotation.TableField;
import com.fasterxml.jackson.annotation.JsonIgnore;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.time.LocalDateTime;
import java.util.List;
/**
 * Conversation message DTO returned to the client: one user/system/assistant
 * message plus its AI responses nested in {@link #children}.
 */
@Data
public class ConvMsgDto {
// Internal primary key; hidden from API output, only used to attach children.
@JsonIgnore
private Long id;
@Schema(title = "消息的uuid")
private String uuid;
@Schema(title = "父级消息id")
private Long parentMessageId;
@Schema(title = "对话的消息")
// NOTE(review): @TableField is a MyBatis-Plus entity annotation; its effect on
// a plain DTO mapped via MPPageUtil.convertToList is unclear — confirm it is needed.
@TableField("remark")
private String remark;
@Schema(title = "产生该消息的角色1: 用户,2:系统,3:助手")
private Integer messageRole;
@Schema(title = "消耗的token数量")
private Integer tokens;
@Schema(title = "创建时间")
private LocalDateTime createTime;
@Schema(title = "model id")
private Long aiModelId;
@Schema(title = "model platform name")
// Resolved from MODEL_ID_TO_OBJ by the service layer; empty string when the
// model id is unknown (see ConversationService).
private String aiModelPlatform;
@Schema(title = "子级消息一般指的是AI的响应")
private List<ConvMsgDto> children;
}

View File

@ -11,5 +11,5 @@ public class ConvMsgListResp {
private String minMsgUuid;
private List<ConvMsgResp> msgList;
private List<ConvMsgDto> msgList;
}

View File

@ -0,0 +1,39 @@
package com.moyz.adi.common.dto;
import com.baomidou.mybatisplus.annotation.TableField;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
/**
 * Knowledge-base Q&amp;A record DTO returned by the record search endpoint;
 * converted from the KnowledgeBaseQaRecord entity with the model platform
 * name resolved from the local model cache.
 */
@Data
public class KbQaRecordDto {
@Schema(title = "uuid")
private String uuid;
@Schema(title = "知识库uuid")
private String kbUuid;
@Schema(title = "来源文档id,以逗号隔开")
private String sourceFileIds;
@Schema(title = "问题")
private String question;
@Schema(title = "最终提供给LLM的提示词")
// NOTE(review): @TableField is a MyBatis-Plus entity annotation; on a DTO
// filled via MPPageUtil.convertToPage its purpose is unclear — confirm.
@TableField("prompt")
private String prompt;
@Schema(title = "提供给LLM的提示词所消耗的token数量")
private Integer promptTokens;
@Schema(title = "答案")
private String answer;
@Schema(title = "答案消耗的token")
private Integer answerTokens;
@Schema(title = "ai model id")
private Long aiModelId;
@Schema(title = "ai model platform")
// Empty string when aiModelId is not present in MODEL_ID_TO_OBJ
// (see KnowledgeBaseQaRecordService.search).
private String aiModelPlatform;
}

View File

@ -1,6 +1,5 @@
package com.moyz.adi.common.entity;
import com.moyz.adi.common.enums.AiModelStatus;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import io.swagger.v3.oas.annotations.media.Schema;
@ -11,15 +10,23 @@ import lombok.Data;
@Schema(title = "AiModel对象", description = "AI模型表")
public class AiModel extends BaseEntity {
@Schema(title = "模型类型:text,image,embedding,rerank")
@TableField("type")
private String type;
@Schema(title = "模型名称")
@TableField("name")
private String name;
@Schema(title = "模型所属平台")
@TableField("platform")
private String platform;
@Schema(title = "说明")
@TableField("remark")
private String remark;
@Schema(title = "状态(1:正常使用,2:不可用)")
@TableField("model_status")
private AiModelStatus modelStatus;
@Schema(title = "状态(1:正常使用,0:不可用)")
@TableField("is_enable")
private Boolean isEnable;
}

View File

@ -58,7 +58,7 @@ public class ConversationMessage extends BaseEntity {
@TableField("understand_context_msg_pair_num")
private Integer understandContextMsgPairNum;
@Schema(name = "LLM name")
@TableField("language_model_name")
private String languageModelName;
@Schema(name = "adi_ai_model id")
@TableField("ai_model_id")
private Long aiModelId;
}

View File

@ -51,4 +51,7 @@ public class KnowledgeBaseQaRecord extends BaseEntity {
@TableField("user_id")
private Long userId;
@Schema(title = "adi_ai_model id")
@TableField("ai_model_id")
private Long aiModelId;
}

View File

@ -1,5 +1,6 @@
package com.moyz.adi.common.helper;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.interfaces.AbstractImageModelService;
import com.moyz.adi.common.vo.ImageModelInfo;
import lombok.extern.slf4j.Slf4j;
@ -35,12 +36,12 @@ public class ImageModelContext {
}
}
public static void addImageModelService(String modelServiceKey, AbstractImageModelService modelService) {
public static void addImageModelService(AbstractImageModelService modelService) {
ImageModelInfo imageModelInfo = new ImageModelInfo();
imageModelInfo.setModelService(modelService);
imageModelInfo.setModelName(modelServiceKey);
imageModelInfo.setModelName(modelService.getAiModel().getName());
imageModelInfo.setEnable(modelService.isEnabled());
NAME_TO_MODEL.put(modelServiceKey, imageModelInfo);
NAME_TO_MODEL.put(modelService.getAiModel().getName(), imageModelInfo);
}
public AbstractImageModelService getModelService() {

View File

@ -1,5 +1,6 @@
package com.moyz.adi.common.helper;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.LocalCache;
@ -34,12 +35,12 @@ public class LLMContext {
}
}
public static void addLLMService(String llmServiceKey, AbstractLLMService llmService) {
public static void addLLMService(AbstractLLMService llmService) {
LLMModelInfo llmModelInfo = new LLMModelInfo();
llmModelInfo.setModelName(llmServiceKey);
llmModelInfo.setModelName(llmService.getAiModel().getName());
llmModelInfo.setEnable(llmService.isEnabled());
llmModelInfo.setLlmService(llmService);
NAME_TO_MODEL.put(llmServiceKey, llmModelInfo);
NAME_TO_MODEL.put(llmService.getAiModel().getName(), llmModelInfo);
}
public AbstractLLMService getLLMService() {

View File

@ -1,6 +1,7 @@
package com.moyz.adi.common.interfaces;
import com.moyz.adi.common.entity.AiImage;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException;
@ -24,7 +25,7 @@ public abstract class AbstractImageModelService<T> {
protected Proxy proxy;
protected String modelName;
protected AiModel aiModel;
protected T setting;
@ -39,8 +40,8 @@ public abstract class AbstractImageModelService<T> {
protected ImageModel imageModel;
public AbstractImageModelService(String modelName, String settingName, Class<T> clazz) {
this.modelName = modelName;
public AbstractImageModelService(AiModel aiModel, String settingName, Class<T> clazz) {
this.aiModel = aiModel;
String st = LocalCache.CONFIGS.get(settingName);
setting = JsonUtil.fromJson(st, clazz);
}
@ -88,4 +89,8 @@ public abstract class AbstractImageModelService<T> {
public abstract List<String> editImage(User user, AiImage aiImage);
public abstract List<String> createImageVariation(User user, AiImage aiImage);
public AiModel getAiModel() {
return aiModel;
}
}

View File

@ -1,5 +1,6 @@
package com.moyz.adi.common.interfaces;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.service.RAGService;
import com.moyz.adi.common.util.JsonUtil;
@ -37,9 +38,9 @@ public abstract class AbstractLLMService<T> {
protected Proxy proxy;
protected String modelName;
protected AiModel aiModel;
protected T setting;
protected T modelPlatformSetting;
protected StreamingChatLanguageModel streamingChatLanguageModel;
protected ChatLanguageModel chatLanguageModel;
@ -50,10 +51,10 @@ public abstract class AbstractLLMService<T> {
private RAGService queryCompressingRagService;
public AbstractLLMService(String modelName, String settingName, Class<T> clazz) {
this.modelName = modelName;
public AbstractLLMService(AiModel aiModel, String settingName, Class<T> clazz) {
this.aiModel = aiModel;
String st = LocalCache.CONFIGS.get(settingName);
setting = JsonUtil.fromJson(st, clazz);
modelPlatformSetting = JsonUtil.fromJson(st, clazz);
}
public AbstractLLMService setProxy(Proxy proxy) {
@ -61,7 +62,7 @@ public abstract class AbstractLLMService<T> {
return this;
}
public AbstractLLMService setQueryCompressingRAGService(RAGService ragService){
public AbstractLLMService setQueryCompressingRAGService(RAGService ragService) {
queryCompressingRagService = ragService;
return this;
}
@ -201,4 +202,8 @@ public abstract class AbstractLLMService<T> {
return chatAssistant.chat(messageId, userMessage);
}
}
public AiModel getAiModel() {
return aiModel;
}
}

View File

@ -0,0 +1,51 @@
package com.moyz.adi.common.service;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.mapper.AiModelMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import java.util.List;
import static com.moyz.adi.common.util.LocalCache.MODEL_ID_TO_OBJ;
import static com.moyz.adi.common.util.LocalCache.MODEL_NAME_TO_OBJ;

/**
 * Service over the adi_ai_model table. Maintains the in-memory model caches
 * (MODEL_NAME_TO_OBJ / MODEL_ID_TO_OBJ) that the rest of the application uses
 * to resolve model names and ids without hitting the database.
 */
@Slf4j
@Service
public class AiModelService extends ServiceImpl<AiModelMapper, AiModel> {

    /**
     * Load every non-deleted model into the local caches.
     * Invoked at startup (from Initializer.init) and re-run every 5 minutes.
     * NOTE(review): entries are only added/overwritten, never evicted — a model
     * soft-deleted after caching stays in the maps until restart; confirm this
     * staleness window is acceptable.
     */
    @Scheduled(fixedDelay = 5 * 60 * 1000)
    public void initAll() {
        List<AiModel> all = ChainWrappers.lambdaQueryChain(baseMapper)
                .eq(AiModel::getIsDeleted, false)
                .list();
        for (AiModel model : all) {
            MODEL_NAME_TO_OBJ.put(model.getName(), model);
            MODEL_ID_TO_OBJ.put(model.getId(), model);
        }
    }

    /**
     * List the non-deleted models belonging to one platform and model type.
     *
     * @param platform platform key, e.g. AdiConstant.ModelPlatform.OPENAI
     * @param type     model type, e.g. AdiConstant.ModelType.TEXT
     * @return matching models; empty list when the platform has none configured
     */
    public List<AiModel> listBy(String platform, String type) {
        return ChainWrappers.lambdaQueryChain(baseMapper)
                .eq(AiModel::getPlatform, platform)
                .eq(AiModel::getType, type)
                .eq(AiModel::getIsDeleted, false)
                .list();
    }

    /**
     * Fetch a non-deleted model by its unique name.
     *
     * @param modelName model name, e.g. "gpt-3.5-turbo"
     * @return the model, or null when no non-deleted row matches
     */
    public AiModel getByName(String modelName) {
        return ChainWrappers.lambdaQueryChain(baseMapper)
                .eq(AiModel::getName, modelName)
                .eq(AiModel::getIsDeleted, false)
                .one();
    }

    /**
     * Resolve a model name to its id.
     *
     * @param modelName model name to look up
     * @return the model id, or 0 when the name is unknown (callers persist this
     *         as a "no model" marker rather than null)
     */
    public Long getIdByName(String modelName) {
        AiModel aiModel = this.getByName(modelName);
        // 0L, not 0l: a lowercase long suffix is easily misread as the digit one.
        return null == aiModel ? 0L : aiModel.getId();
    }
}

View File

@ -61,6 +61,9 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
@Resource
private SSEEmitterHelper sseEmitterHelper;
@Resource
private AiModelService aiModelService;
public SseEmitter sseAsk(AskReq askReq) {
SseEmitter sseEmitter = new SseEmitter();
@ -143,30 +146,6 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
if (Boolean.TRUE.equals(conversation.getUnderstandContextEnable())) {
sseAskParams.setMessageId(askReq.getConversationUuid());
}
// List<ConversationMessage> historyMsgList = this.lambdaQuery()
// .eq(ConversationMessage::getUserId, user.getId())
// .eq(ConversationMessage::getConversationUuid, askReq.getConversationUuid())
// .orderByDesc(ConversationMessage::getId)
// .last("limit " + user.getUnderstandContextMsgPairNum() * 2)
// .list();
// if (!historyMsgList.isEmpty()) {
// ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(1000, new OpenAiTokenizer(GPT_3_5_TURBO));
// historyMsgList.sort(Comparator.comparing(ConversationMessage::getId));
// for (ConversationMessage historyMsg : historyMsgList) {
// if (ChatMessageRoleEnum.USER.getValue().equals(historyMsg.getMessageRole())) {
// UserMessage userMessage = UserMessage.from(historyMsg.getRemark());
// chatMemory.add(userMessage);
// } else if (ChatMessageRoleEnum.SYSTEM.getValue().equals(historyMsg.getMessageRole())) {
// SystemMessage systemMessage = SystemMessage.from(historyMsg.getRemark());
// chatMemory.add(systemMessage);
// }else if (ChatMessageRoleEnum.ASSISTANT.getValue().equals(historyMsg.getMessageRole())) {
// AiMessage aiMessage = AiMessage.from(historyMsg.getRemark());
// chatMemory.add(aiMessage);
// }
// }
// sseAskParams.setChatMemory(chatMemory);
// }
// }
}
sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, questionMeta, answerMeta) -> {
_this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta);
@ -228,7 +207,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
aiAnswer.setTokens(answerMeta.getTokens());
aiAnswer.setParentMessageId(promptMsg.getId());
aiAnswer.setSecretKeyType(secretKeyType);
aiAnswer.setLanguageModelName(askReq.getModelName());
aiAnswer.setAiModelId(aiModelService.getIdByName(askReq.getModelName()));
baseMapper.insert(aiAnswer);
calcTodayCost(user, conversation, questionMeta, answerMeta);

View File

@ -5,7 +5,8 @@ import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.dto.ConvDto;
import com.moyz.adi.common.dto.ConvEditReq;
import com.moyz.adi.common.dto.ConvMsgListResp;
import com.moyz.adi.common.dto.ConvMsgResp;
import com.moyz.adi.common.dto.ConvMsgDto;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.entity.Conversation;
import com.moyz.adi.common.entity.ConversationMessage;
import com.moyz.adi.common.exception.BaseException;
@ -19,9 +20,12 @@ import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static com.moyz.adi.common.enums.ErrorEnum.A_CONVERSATION_EXIST;
import static com.moyz.adi.common.enums.ErrorEnum.A_CONVERSATION_NOT_EXIST;
import static com.moyz.adi.common.util.LocalCache.MODEL_ID_TO_OBJ;
import static com.moyz.adi.common.util.LocalCache.MODEL_NAME_TO_OBJ;
@Slf4j
@Service
@ -33,6 +37,9 @@ public class ConversationService extends ServiceImpl<ConversationMapper, Convers
@Resource
private ConversationMessageService conversationMessageService;
@Resource
private AiModelService aiModelService;
public List<ConvDto> listByUser() {
List<Conversation> list = this.lambdaQuery()
.eq(Conversation::getUserId, ThreadContext.getCurrentUserId())
@ -81,7 +88,7 @@ public class ConversationService extends ServiceImpl<ConversationMapper, Convers
return b;
}).getUuid();
//Wrap question content
List<ConvMsgResp> userMessages = MPPageUtil.convertToList(questions, ConvMsgResp.class);
List<ConvMsgDto> userMessages = MPPageUtil.convertToList(questions, ConvMsgDto.class);
ConvMsgListResp result = new ConvMsgListResp(minUuid, userMessages);
//Wrap answer content
@ -95,9 +102,14 @@ public class ConversationService extends ServiceImpl<ConversationMapper, Convers
//Fill AI answer to the request of user
result.getMsgList().forEach(item -> {
List<ConvMsgResp> children = MPPageUtil.convertToList(idToMessages.get(item.getId()), ConvMsgResp.class);
List<ConvMsgDto> children = MPPageUtil.convertToList(idToMessages.get(item.getId()), ConvMsgDto.class);
if (children.size() > 1) {
children = children.stream().sorted(Comparator.comparing(ConvMsgResp::getCreateTime).reversed()).collect(Collectors.toList());
children = children.stream().sorted(Comparator.comparing(ConvMsgDto::getCreateTime).reversed()).collect(Collectors.toList());
}
for (ConvMsgDto convMsgDto : children) {
AiModel aiModel = MODEL_ID_TO_OBJ.get(convMsgDto.getAiModelId());
convMsgDto.setAiModelPlatform(null == aiModel ? "" : aiModel.getPlatform());
}
item.setChildren(children);
});

View File

@ -1,6 +1,7 @@
package com.moyz.adi.common.service;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.vo.DashScopeSetting;
@ -11,8 +12,6 @@ import dev.langchain4j.model.dashscope.QwenStreamingChatModel;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.net.Proxy;
import static com.moyz.adi.common.enums.ErrorEnum.B_LLM_SECRET_KEY_NOT_SET;
/**
@ -21,23 +20,23 @@ import static com.moyz.adi.common.enums.ErrorEnum.B_LLM_SECRET_KEY_NOT_SET;
@Slf4j
public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> {
public DashScopeLLMService(String modelName) {
super(modelName, AdiConstant.SysConfigKey.DASHSCOPE_SETTING, DashScopeSetting.class);
public DashScopeLLMService(AiModel aiModel) {
super(aiModel, AdiConstant.SysConfigKey.DASHSCOPE_SETTING, DashScopeSetting.class);
}
@Override
public boolean isEnabled() {
return StringUtils.isNotBlank(setting.getApiKey());
return StringUtils.isNotBlank(modelPlatformSetting.getApiKey()) && aiModel.getIsEnable();
}
@Override
protected StreamingChatLanguageModel buildStreamingChatLLM() {
if (StringUtils.isBlank(setting.getApiKey())) {
if (StringUtils.isBlank(modelPlatformSetting.getApiKey())) {
throw new BaseException(B_LLM_SECRET_KEY_NOT_SET);
}
return QwenStreamingChatModel.builder()
.apiKey(setting.getApiKey())
.modelName(modelName)
.apiKey(modelPlatformSetting.getApiKey())
.modelName(aiModel.getName())
.build();
}
@ -48,12 +47,12 @@ public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> {
@Override
protected ChatLanguageModel buildChatLLM() {
if (StringUtils.isBlank(setting.getApiKey())) {
if (StringUtils.isBlank(modelPlatformSetting.getApiKey())) {
throw new BaseException(B_LLM_SECRET_KEY_NOT_SET);
}
return QwenChatModel.builder()
.apiKey(setting.getApiKey())
.modelName(modelName)
.apiKey(modelPlatformSetting.getApiKey())
.modelName(aiModel.getName())
.build();
}

View File

@ -1,19 +1,24 @@
package com.moyz.adi.common.service;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.helper.ImageModelContext;
import com.moyz.adi.common.helper.LLMContext;
import com.moyz.adi.common.interfaces.AbstractImageModelService;
import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.searchengine.GoogleSearchEngine;
import com.moyz.adi.common.searchengine.SearchEngineContext;
import dev.langchain4j.model.openai.OpenAiModelName;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.List;
import java.util.function.Function;
@Slf4j
@Service
@ -28,6 +33,9 @@ public class Initializer {
@Value("${adi.proxy.http-port:0}")
protected int proxyHttpPort;
@Resource
private AiModelService aiModelService;
@Resource
private SysConfigService sysConfigService;
@ -37,52 +45,53 @@ public class Initializer {
@PostConstruct
public void init() {
sysConfigService.reload();
aiModelService.initAll();
Proxy proxy = null;
Proxy proxy;
if (proxyEnable) {
proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort));
} else {
proxy = null;
}
//openai
String[] openaiModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OPENAI_SETTING);
if (openaiModels.length == 0) {
log.warn("openai service is disabled");
}
for (String model : openaiModels) {
LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy).setQueryCompressingRAGService(queryCompressingRagService));
}
initLLMService(AdiConstant.ModelPlatform.OPENAI, (model) -> new OpenAiLLMService(model).setProxy(proxy).setQueryCompressingRAGService(queryCompressingRagService));
//dashscope
String[] dashscopeModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.DASHSCOPE_SETTING);
if (dashscopeModels.length == 0) {
log.warn("dashscope service is disabled");
}
for (String model : dashscopeModels) {
LLMContext.addLLMService(model, new DashScopeLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
}
initLLMService(AdiConstant.ModelPlatform.DASHSCOPE, (model) -> new DashScopeLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
//qianfan
String[] qianfanModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.QIANFAN_SETTING);
if (qianfanModels.length == 0) {
log.warn("qianfan service is disabled");
}
for (String model : qianfanModels) {
LLMContext.addLLMService(model, new QianFanLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
}
initLLMService(AdiConstant.ModelPlatform.QIANFAN, (model) -> new QianFanLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
//ollama
String[] ollamaModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OLLAMA_SETTING);
if (ollamaModels.length == 0) {
log.warn("ollama service is disabled");
}
for (String model : ollamaModels) {
LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
}
initLLMService(AdiConstant.ModelPlatform.OLLAMA, (model) -> new OllamaLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2).setProxy(proxy));
//openai image model
initImageModelService(AdiConstant.ModelPlatform.OPENAI, (model) -> new OpenAiImageModelService(model).setProxy(proxy));
//search engine
SearchEngineContext.addEngine(AdiConstant.SearchEngineName.GOOGLE, new GoogleSearchEngine().setProxy(proxy));
}
// Register an LLM service for every TEXT model configured for the given
// platform; the factory turns each AiModel row into a concrete service.
// Only warns (does not fail) when the platform has no models configured.
private void initLLMService(String platform, Function<AiModel, AbstractLLMService> function) {
List<AiModel> models = aiModelService.listBy(platform, AdiConstant.ModelType.TEXT);
if (CollectionUtils.isEmpty(models)) {
log.warn("{} service is disabled", platform);
}
for (AiModel model : models) {
LLMContext.addLLMService(function.apply(model));
}
}
// Register an image-model service for every IMAGE model configured for the
// given platform; mirrors initLLMService but targets ImageModelContext.
// Only warns (does not fail) when the platform has no models configured.
private void initImageModelService(String platform, Function<AiModel, AbstractImageModelService> function) {
List<AiModel> models = aiModelService.listBy(platform, AdiConstant.ModelType.IMAGE);
if (CollectionUtils.isEmpty(models)) {
log.warn("{} service is disabled", platform);
}
for (AiModel model : models) {
ImageModelContext.addImageModelService(function.apply(model));
}
}
}

View File

@ -5,11 +5,15 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.dto.KbQaRecordDto;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.entity.KnowledgeBase;
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.mapper.KnowledgeBaseQaRecordMapper;
import com.moyz.adi.common.util.MPPageUtil;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
@ -17,12 +21,16 @@ import org.springframework.stereotype.Service;
import java.util.UUID;
import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND;
import static com.moyz.adi.common.util.LocalCache.MODEL_ID_TO_OBJ;
@Slf4j
@Service
public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRecordMapper, KnowledgeBaseQaRecord> {
public Page<KnowledgeBaseQaRecord> search(String kbUuid, String keyword, Integer currentPage, Integer pageSize) {
@Resource
private AiModelService aiModelService;
public Page<KbQaRecordDto> search(String kbUuid, String keyword, Integer currentPage, Integer pageSize) {
LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>();
wrapper.eq(KnowledgeBaseQaRecord::getKbUuid, kbUuid);
wrapper.eq(KnowledgeBaseQaRecord::getIsDeleted, false);
@ -33,7 +41,15 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
wrapper.like(KnowledgeBaseQaRecord::getQuestion, keyword);
}
wrapper.orderByDesc(KnowledgeBaseQaRecord::getUpdateTime);
return baseMapper.selectPage(new Page<>(currentPage, pageSize), wrapper);
Page<KnowledgeBaseQaRecord> page = baseMapper.selectPage(new Page<>(currentPage, pageSize), wrapper);
Page<KbQaRecordDto> result = new Page<>();
MPPageUtil.convertToPage(page, result, KbQaRecordDto.class, (t1, t2) -> {
AiModel aiModel = MODEL_ID_TO_OBJ.get(t1.getAiModelId());
t2.setAiModelPlatform(null == aiModel ? "" : aiModel.getPlatform());
return t2;
});
return result;
}
/**
@ -45,9 +61,10 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
* @param promptTokens 提示词消耗的token
* @param answer 答案
* @param answerTokens 答案消耗的token
* @param modelName ai model name
* @return
*/
public KnowledgeBaseQaRecord createNewRecord(User user, KnowledgeBase knowledgeBase, String question, String prompt, int promptTokens, String answer, int answerTokens) {
public KnowledgeBaseQaRecord createNewRecord(User user, KnowledgeBase knowledgeBase, String question, String prompt, int promptTokens, String answer, int answerTokens, String modelName) {
String uuid = UUID.randomUUID().toString().replace("-", "");
KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord();
newObj.setKbId(knowledgeBase.getId());
@ -59,6 +76,7 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
newObj.setPromptTokens(promptTokens);
newObj.setAnswer(answer);
newObj.setAnswerTokens(answerTokens);
newObj.setAiModelId(aiModelService.getIdByName(modelName));
baseMapper.insert(newObj);
LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>();

View File

@ -244,7 +244,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
int inputTokenCount = ar.tokenUsage().inputTokenCount();
int outputTokenCount = ar.tokenUsage().outputTokenCount();
userDayCostService.appendCostToUser(ThreadContext.getCurrentUser(), inputTokenCount + outputTokenCount);
return knowledgeBaseQaRecordService.createNewRecord(ThreadContext.getCurrentUser(), knowledgeBase, question, responsePair.getLeft(), inputTokenCount, ar.content().text(), outputTokenCount);
return knowledgeBaseQaRecordService.createNewRecord(ThreadContext.getCurrentUser(), knowledgeBase, question, responsePair.getLeft(), inputTokenCount, ar.content().text(), outputTokenCount, modelName);
}
public SseEmitter sseAsk(String kbUuid, QAReq req) {
@ -328,7 +328,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
sseAskParams.setUserMessage(promptText);
sseAskParams.setModelName(req.getModelName());
sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, promptMeta, answerMeta) -> {
knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), promptText, promptMeta.getTokens(), response, answerMeta.getTokens());
knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), promptText, promptMeta.getTokens(), response, answerMeta.getTokens(), req.getModelName());
userDayCostService.appendCostToUser(user, promptMeta.getTokens() + answerMeta.getTokens());
});
}

View File

@ -1,5 +1,6 @@
package com.moyz.adi.common.service;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.vo.OllamaSetting;
import dev.langchain4j.model.chat.ChatLanguageModel;
@ -12,20 +13,20 @@ import static com.moyz.adi.common.cosntant.AdiConstant.SysConfigKey.OLLAMA_SETTI
public class OllamaLLMService extends AbstractLLMService<OllamaSetting> {
public OllamaLLMService(String modelName) {
super(modelName, OLLAMA_SETTING, OllamaSetting.class);
public OllamaLLMService(AiModel aiModel) {
super(aiModel, OLLAMA_SETTING, OllamaSetting.class);
}
@Override
public boolean isEnabled() {
return StringUtils.isNotBlank(setting.getBaseUrl());
return StringUtils.isNotBlank(modelPlatformSetting.getBaseUrl()) && aiModel.getIsEnable();
}
@Override
protected ChatLanguageModel buildChatLLM() {
return OllamaChatModel.builder()
.baseUrl(setting.getBaseUrl())
.modelName(modelName)
.baseUrl(modelPlatformSetting.getBaseUrl())
.modelName(aiModel.getName())
.temperature(0.0)
.build();
}
@ -33,8 +34,8 @@ public class OllamaLLMService extends AbstractLLMService<OllamaSetting> {
@Override
protected StreamingChatLanguageModel buildStreamingChatLLM() {
return OllamaStreamingChatModel.builder()
.baseUrl(setting.getBaseUrl())
.modelName(modelName)
.baseUrl(modelPlatformSetting.getBaseUrl())
.modelName(aiModel.getName())
.build();
}

View File

@ -3,6 +3,7 @@ package com.moyz.adi.common.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.entity.AiImage;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException;
@ -45,13 +46,13 @@ public class OpenAiImageModelService extends AbstractImageModelService<OpenAiSet
@Resource
private ObjectMapper objectMapper;
public OpenAiImageModelService(String modelName) {
super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class);
public OpenAiImageModelService(AiModel aiModel) {
super(aiModel, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class);
}
@Override
public boolean isEnabled() {
return StringUtils.isNotBlank(setting.getSecretKey());
return StringUtils.isNotBlank(setting.getSecretKey()) && aiModel.getIsEnable();
}
@Override
@ -60,7 +61,7 @@ public class OpenAiImageModelService extends AbstractImageModelService<OpenAiSet
throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET);
}
OpenAiImageModel.OpenAiImageModelBuilder builder = OpenAiImageModel.builder()
.modelName(modelName)
.modelName(aiModel.getName())
.apiKey(setting.getSecretKey())
.user(user.getUuid())
.responseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL)

View File

@ -1,6 +1,7 @@
package com.moyz.adi.common.service;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.interfaces.AbstractLLMService;
@ -27,24 +28,24 @@ import java.time.temporal.ChronoUnit;
@Accessors(chain = true)
public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
public OpenAiLLMService(String modelName) {
super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class);
public OpenAiLLMService(AiModel model) {
super(model, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class);
}
@Override
public boolean isEnabled() {
return StringUtils.isNotBlank(setting.getSecretKey());
return StringUtils.isNotBlank(modelPlatformSetting.getSecretKey()) && aiModel.getIsEnable();
}
@Override
protected StreamingChatLanguageModel buildStreamingChatLLM() {
if (StringUtils.isBlank(setting.getSecretKey())) {
if (StringUtils.isBlank(modelPlatformSetting.getSecretKey())) {
throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET);
}
OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder builder = OpenAiStreamingChatModel
.builder()
.modelName(modelName)
.apiKey(setting.getSecretKey())
.modelName(aiModel.getName())
.apiKey(modelPlatformSetting.getSecretKey())
.timeout(Duration.of(60, ChronoUnit.SECONDS));
if (null != proxy) {
builder.proxy(proxy);
@ -54,8 +55,8 @@ public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
@Override
protected String parseError(Object error) {
if(error instanceof OpenAiHttpException){
OpenAiHttpException openAiHttpException = (OpenAiHttpException)error;
if (error instanceof OpenAiHttpException) {
OpenAiHttpException openAiHttpException = (OpenAiHttpException) error;
OpenAiError openAiError = JsonUtil.fromJson(openAiHttpException.getMessage(), OpenAiError.class);
return openAiError.getError().getMessage();
}
@ -64,10 +65,10 @@ public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
@Override
protected ChatLanguageModel buildChatLLM() {
if (StringUtils.isBlank(setting.getSecretKey())) {
if (StringUtils.isBlank(modelPlatformSetting.getSecretKey())) {
throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET);
}
OpenAiChatModel.OpenAiChatModelBuilder builder = OpenAiChatModel.builder().apiKey(setting.getSecretKey());
OpenAiChatModel.OpenAiChatModelBuilder builder = OpenAiChatModel.builder().apiKey(modelPlatformSetting.getSecretKey());
if (null != proxy) {
builder.proxy(proxy);
}

View File

@ -1,6 +1,7 @@
package com.moyz.adi.common.service;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.vo.QianFanSetting;
import dev.langchain4j.model.chat.ChatLanguageModel;
@ -18,35 +19,35 @@ import org.apache.commons.lang3.StringUtils;
@Accessors(chain = true)
public class QianFanLLMService extends AbstractLLMService<QianFanSetting> {
public QianFanLLMService(String modelName) {
super(modelName, AdiConstant.SysConfigKey.QIANFAN_SETTING, QianFanSetting.class);
public QianFanLLMService(AiModel aiModel) {
super(aiModel, AdiConstant.SysConfigKey.QIANFAN_SETTING, QianFanSetting.class);
}
@Override
public boolean isEnabled() {
return StringUtils.isNoneBlank(setting.getApiKey(), setting.getSecretKey());
return StringUtils.isNoneBlank(modelPlatformSetting.getApiKey(), modelPlatformSetting.getSecretKey()) && aiModel.getIsEnable();
}
@Override
protected ChatLanguageModel buildChatLLM() {
return QianfanChatModel.builder()
.modelName(modelName)
.modelName(aiModel.getName())
.temperature(0.7)
.topP(1.0)
.maxRetries(1)
.apiKey(setting.getApiKey())
.secretKey(setting.getSecretKey())
.apiKey(modelPlatformSetting.getApiKey())
.secretKey(modelPlatformSetting.getSecretKey())
.build();
}
@Override
protected StreamingChatLanguageModel buildStreamingChatLLM() {
return QianfanStreamingChatModel.builder()
.modelName(modelName)
.modelName(aiModel.getName())
.temperature(0.7)
.topP(1.0)
.apiKey(setting.getApiKey())
.secretKey(setting.getSecretKey())
.apiKey(modelPlatformSetting.getApiKey())
.secretKey(modelPlatformSetting.getSecretKey())
.build();
}

View File

@ -1,14 +1,19 @@
package com.moyz.adi.common.util;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.vo.RequestRateLimit;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class LocalCache {
public static final Map<String, String> CONFIGS = new HashMap<>();
public static final Map<String, String> CONFIGS = new ConcurrentHashMap<>();
public static RequestRateLimit TEXT_RATE_LIMIT_CONFIG;
public static RequestRateLimit IMAGE_RATE_LIMIT_CONFIG;
public static Map<String, AiModel> MODEL_NAME_TO_OBJ = new ConcurrentHashMap<>();
public static Map<Long, AiModel> MODEL_ID_TO_OBJ = new ConcurrentHashMap<>();
}

View File

@ -50,25 +50,38 @@ COMMENT ON COLUMN public.adi_ai_image.is_deleted IS 'Flag indicating whether the
CREATE TABLE public.adi_ai_model
(
id bigserial primary key,
name character varying(45) DEFAULT ''::character varying NOT NULL,
remark character varying(1000),
model_status smallint DEFAULT '1'::smallint NOT NULL,
name varchar(45) default '' not null,
type varchar(45) default 'llm' not null,
remark varchar(1000) default '',
platform varchar(45) default '' not null,
max_tokens int default 0 not null,
is_enable boolean default false NOT NULL,
create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
update_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
is_deleted boolean DEFAULT false NOT NULL,
CONSTRAINT adi_ai_model_model_status_check CHECK ((model_status = ANY (ARRAY [1, 2])))
is_deleted boolean DEFAULT false NOT NULL
);
COMMENT ON TABLE public.adi_ai_model IS 'ai模型';
COMMENT ON COLUMN public.adi_ai_model.type IS 'The type of the AI model,eg: text,image,embedding,rerank';
COMMENT ON COLUMN public.adi_ai_model.name IS 'The name of the AI model';
COMMENT ON COLUMN public.adi_ai_model.remark IS 'Additional remarks about the AI model';
COMMENT ON COLUMN public.adi_ai_model.model_status IS '1: Normal usage, 2: Not available';
COMMENT ON COLUMN public.adi_ai_model.platform IS 'eg: openai,dashscope,qianfan,ollama';
COMMENT ON COLUMN public.adi_ai_model.max_tokens IS 'The maximum number of tokens that can be generated';
COMMENT ON COLUMN public.adi_ai_model.is_enable IS '1: Normal usage, 0: Not available';
COMMENT ON COLUMN public.adi_ai_model.create_time IS 'Timestamp of record creation';
COMMENT ON COLUMN public.adi_ai_model.update_time IS 'Timestamp of record last update, automatically updated on each update';
INSERT INTO adi_ai_model (name, type, platform, max_tokens, is_enable)
VALUES ('gpt-3.5-turbo', 'text', 'openai', 2048, false);
INSERT INTO adi_ai_model (name, type, platform, is_enable)
VALUES ('dall-e-2', 'image', 'openai', false);
INSERT INTO adi_ai_model (name, type, platform, is_enable)
VALUES ('qwen-turbo', 'text', 'dashscope', false);
INSERT INTO adi_ai_model (name, type, platform, is_enable)
VALUES ('ernie-3.5-8k-0205', 'text', 'qianfan', false);
INSERT INTO adi_ai_model (name, type, platform, is_enable)
VALUES ('tinydolphin', 'text', 'ollama', false);
CREATE TABLE public.adi_conversation
(
id bigserial primary key,
@ -103,8 +116,8 @@ CREATE TABLE public.adi_conversation_message
uuid character varying(32) DEFAULT ''::character varying NOT NULL,
message_role integer DEFAULT 1 NOT NULL,
tokens integer DEFAULT 0 NOT NULL,
language_model_name character varying(32) DEFAULT ''::character varying NOT NULL,
user_id bigint DEFAULT '0'::bigint NOT NULL,
ai_model_id bigint default 0 not null,
secret_key_type integer DEFAULT 1 NOT NULL,
understand_context_msg_pair_num integer DEFAULT 0 NOT NULL,
create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
@ -131,6 +144,8 @@ COMMENT ON COLUMN public.adi_conversation_message.language_model_name IS 'LLM na
COMMENT ON COLUMN public.adi_conversation_message.user_id IS '用户ID';
COMMENT ON COLUMN public.adi_conversation_message.ai_model_id IS 'adi_ai_model id';
COMMENT ON COLUMN public.adi_conversation_message.secret_key_type IS '加密密钥类型';
COMMENT ON COLUMN public.adi_conversation_message.understand_context_msg_pair_num IS '上下文消息对数量';
@ -381,6 +396,12 @@ CREATE TRIGGER trigger_user_day_cost_update_time
FOR EACH ROW
EXECUTE PROCEDURE update_modified_column();
create trigger trigger_ai_model
before update
on adi_ai_model
for each row
execute procedure update_modified_column();
INSERT INTO adi_sys_config (name, value)
VALUES ('openai_setting', '{"secret_key":"","models":[]}');
INSERT INTO adi_sys_config (name, value)
@ -549,6 +570,7 @@ create table adi_knowledge_base_qa_record
answer_tokens integer DEFAULT 0 NOT NULL,
source_file_ids varchar(500) default ''::character varying not null,
user_id bigint default '0' NOT NULL,
ai_model_id bigint default 0 not null,
create_time timestamp default CURRENT_TIMESTAMP not null,
update_time timestamp default CURRENT_TIMESTAMP not null,
is_deleted boolean default false not null