模型配置放到表中

This commit is contained in:
moyangzhan 2024-05-05 18:40:32 +08:00
parent b9f2d78a4f
commit 2341c92cf8
26 changed files with 383 additions and 149 deletions

View File

@ -70,25 +70,39 @@ vue3+typescript+pnpm
openai的secretKey openai的secretKey
```plaintext ```plaintext
update adi_sys_config set value = '{"secret_key":"my_openai_secret_key","models":["gpt-3.5-turbo"]}' where name = 'openai_setting'; update adi_sys_config set value = '{"secret_key":"my_openai_secret_key"}' where name = 'openai_setting';
``` ```
灵积大模型平台的apiKey 灵积大模型平台的apiKey
```plaintext ```plaintext
update adi_sys_config set value = '{"api_key":"my_dashcope_api_key","models":["my model name,eg:qwen-max"]}' where name = 'dashscope_setting'; update adi_sys_config set value = '{"api_key":"my_dashcope_api_key"}' where name = 'dashscope_setting';
``` ```
千帆大模型平台的配置 千帆大模型平台的配置
```plaintext ```plaintext
update adi_sys_config set value = '{"api_key":"my_qianfan_api_key","secret_key":"my_qianfan_secret_key","models":["my model name,eg:ERNIE-Bot"]}' where name = 'qianfan_setting'; update adi_sys_config set value = '{"api_key":"my_qianfan_api_key","secret_key":"my_qianfan_secret_key"}' where name = 'qianfan_setting';
``` ```
ollama的配置 ollama的配置
``` ```
update adi_sys_config set value = '{"base_url":"my_ollama_base_url","models":["my model name,eg:tinydolphin"]}' where name = 'ollama_setting'; update adi_sys_config set value = '{"base_url":"my_ollama_base_url"}' where name = 'ollama_setting';
```
* 启用模型或新增模型
```
-- Enable model
update adi_ai_model set is_enable = true where name = 'gpt-3.5-turbo';
update adi_ai_model set is_enable = true where name = 'dall-e-2';
update adi_ai_model set is_enable = true where name = 'qwen-turbo';
update adi_ai_model set is_enable = true where name = 'ernie-3.5-8k-0205';
update adi_ai_model set is_enable = true where name = 'tinydolphin';
-- Add new model
INSERT INTO adi_ai_model (name, type, platform, is_enable) VALUES ('vicuna', 'text', 'ollama', true);
``` ```
* 填充搜索引擎的配置 * 填充搜索引擎的配置

View File

@ -1,6 +1,7 @@
package com.moyz.adi.chat.controller; package com.moyz.adi.chat.controller;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.moyz.adi.common.dto.KbQaRecordDto;
import com.moyz.adi.common.dto.QAReq; import com.moyz.adi.common.dto.QAReq;
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord; import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
import com.moyz.adi.common.service.KnowledgeBaseQaRecordService; import com.moyz.adi.common.service.KnowledgeBaseQaRecordService;
@ -38,7 +39,7 @@ public class KnowledgeBaseQAController {
} }
@GetMapping("/record/search") @GetMapping("/record/search")
public Page<KnowledgeBaseQaRecord> list(String kbUuid, String keyword, @NotNull @Min(1) Integer currentPage, @NotNull @Min(10) Integer pageSize) { public Page<KbQaRecordDto> list(String kbUuid, String keyword, @NotNull @Min(1) Integer currentPage, @NotNull @Min(10) Integer pageSize) {
return knowledgeBaseQaRecordService.search(kbUuid, keyword, currentPage, pageSize); return knowledgeBaseQaRecordService.search(kbUuid, keyword, currentPage, pageSize);
} }

View File

@ -102,6 +102,21 @@ public class AdiConstant {
public static final String[] POI_DOC_TYPES = {"doc", "docx", "ppt", "pptx", "xls", "xlsx"}; public static final String[] POI_DOC_TYPES = {"doc", "docx", "ppt", "pptx", "xls", "xlsx"};
public static class ModelPlatform {
public static final String OPENAI = "openai";
public static final String DASHSCOPE = "dashscope";
public static final String QIANFAN = "qianfan";
public static final String OLLAMA = "ollama";
}
public static class ModelType {
public static final String TEXT = "text";
public static final String IMAGE = "image";
public static final String EMBEDDING = "embedding";
public static final String RERANK = "rerank";
}
public static class SearchEngineName { public static class SearchEngineName {
public static final String GOOGLE = "google"; public static final String GOOGLE = "google";
public static final String BING = "bing"; public static final String BING = "bing";

View File

@ -0,0 +1,44 @@
package com.moyz.adi.common.dto;
import com.baomidou.mybatisplus.annotation.TableField;
import com.fasterxml.jackson.annotation.JsonIgnore;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.time.LocalDateTime;
import java.util.List;
@Data
public class ConvMsgDto {
@JsonIgnore
private Long id;
@Schema(title = "消息的uuid")
private String uuid;
@Schema(title = "父级消息id")
private Long parentMessageId;
@Schema(title = "对话的消息")
@TableField("remark")
private String remark;
@Schema(title = "产生该消息的角色1: 用户,2:系统,3:助手")
private Integer messageRole;
@Schema(title = "消耗的token数量")
private Integer tokens;
@Schema(title = "创建时间")
private LocalDateTime createTime;
@Schema(title = "model id")
private Long aiModelId;
@Schema(title = "model platform name")
private String aiModelPlatform;
@Schema(title = "子级消息一般指的是AI的响应")
private List<ConvMsgDto> children;
}

View File

@ -11,5 +11,5 @@ public class ConvMsgListResp {
private String minMsgUuid; private String minMsgUuid;
private List<ConvMsgResp> msgList; private List<ConvMsgDto> msgList;
} }

View File

@ -0,0 +1,39 @@
package com.moyz.adi.common.dto;
import com.baomidou.mybatisplus.annotation.TableField;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
@Data
public class KbQaRecordDto {
@Schema(title = "uuid")
private String uuid;
@Schema(title = "知识库uuid")
private String kbUuid;
@Schema(title = "来源文档id,以逗号隔开")
private String sourceFileIds;
@Schema(title = "问题")
private String question;
@Schema(title = "最终提供给LLM的提示词")
@TableField("prompt")
private String prompt;
@Schema(title = "提供给LLM的提示词所消耗的token数量")
private Integer promptTokens;
@Schema(title = "答案")
private String answer;
@Schema(title = "答案消耗的token")
private Integer answerTokens;
@Schema(title = "ai model id")
private Long aiModelId;
@Schema(title = "ai model platform")
private String aiModelPlatform;
}

View File

@ -1,6 +1,5 @@
package com.moyz.adi.common.entity; package com.moyz.adi.common.entity;
import com.moyz.adi.common.enums.AiModelStatus;
import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
@ -11,15 +10,23 @@ import lombok.Data;
@Schema(title = "AiModel对象", description = "AI模型表") @Schema(title = "AiModel对象", description = "AI模型表")
public class AiModel extends BaseEntity { public class AiModel extends BaseEntity {
@Schema(title = "模型类型:text,image,embedding,rerank")
@TableField("type")
private String type;
@Schema(title = "模型名称") @Schema(title = "模型名称")
@TableField("name") @TableField("name")
private String name; private String name;
@Schema(title = "模型所属平台")
@TableField("platform")
private String platform;
@Schema(title = "说明") @Schema(title = "说明")
@TableField("remark") @TableField("remark")
private String remark; private String remark;
@Schema(title = "状态(1:正常使用,2:不可用)") @Schema(title = "状态(1:正常使用,0:不可用)")
@TableField("model_status") @TableField("is_enable")
private AiModelStatus modelStatus; private Boolean isEnable;
} }

View File

@ -58,7 +58,7 @@ public class ConversationMessage extends BaseEntity {
@TableField("understand_context_msg_pair_num") @TableField("understand_context_msg_pair_num")
private Integer understandContextMsgPairNum; private Integer understandContextMsgPairNum;
@Schema(name = "LLM name") @Schema(name = "adi_ai_model id")
@TableField("language_model_name") @TableField("ai_model_id")
private String languageModelName; private Long aiModelId;
} }

View File

@ -51,4 +51,7 @@ public class KnowledgeBaseQaRecord extends BaseEntity {
@TableField("user_id") @TableField("user_id")
private Long userId; private Long userId;
@Schema(title = "adi_ai_model id")
@TableField("ai_model_id")
private Long aiModelId;
} }

View File

@ -1,5 +1,6 @@
package com.moyz.adi.common.helper; package com.moyz.adi.common.helper;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.interfaces.AbstractImageModelService; import com.moyz.adi.common.interfaces.AbstractImageModelService;
import com.moyz.adi.common.vo.ImageModelInfo; import com.moyz.adi.common.vo.ImageModelInfo;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -35,12 +36,12 @@ public class ImageModelContext {
} }
} }
public static void addImageModelService(String modelServiceKey, AbstractImageModelService modelService) { public static void addImageModelService(AbstractImageModelService modelService) {
ImageModelInfo imageModelInfo = new ImageModelInfo(); ImageModelInfo imageModelInfo = new ImageModelInfo();
imageModelInfo.setModelService(modelService); imageModelInfo.setModelService(modelService);
imageModelInfo.setModelName(modelServiceKey); imageModelInfo.setModelName(modelService.getAiModel().getName());
imageModelInfo.setEnable(modelService.isEnabled()); imageModelInfo.setEnable(modelService.isEnabled());
NAME_TO_MODEL.put(modelServiceKey, imageModelInfo); NAME_TO_MODEL.put(modelService.getAiModel().getName(), imageModelInfo);
} }
public AbstractImageModelService getModelService() { public AbstractImageModelService getModelService() {

View File

@ -1,5 +1,6 @@
package com.moyz.adi.common.helper; package com.moyz.adi.common.helper;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.interfaces.AbstractLLMService; import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.util.JsonUtil; import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.LocalCache; import com.moyz.adi.common.util.LocalCache;
@ -34,12 +35,12 @@ public class LLMContext {
} }
} }
public static void addLLMService(String llmServiceKey, AbstractLLMService llmService) { public static void addLLMService(AbstractLLMService llmService) {
LLMModelInfo llmModelInfo = new LLMModelInfo(); LLMModelInfo llmModelInfo = new LLMModelInfo();
llmModelInfo.setModelName(llmServiceKey); llmModelInfo.setModelName(llmService.getAiModel().getName());
llmModelInfo.setEnable(llmService.isEnabled()); llmModelInfo.setEnable(llmService.isEnabled());
llmModelInfo.setLlmService(llmService); llmModelInfo.setLlmService(llmService);
NAME_TO_MODEL.put(llmServiceKey, llmModelInfo); NAME_TO_MODEL.put(llmService.getAiModel().getName(), llmModelInfo);
} }
public AbstractLLMService getLLMService() { public AbstractLLMService getLLMService() {

View File

@ -1,6 +1,7 @@
package com.moyz.adi.common.interfaces; package com.moyz.adi.common.interfaces;
import com.moyz.adi.common.entity.AiImage; import com.moyz.adi.common.entity.AiImage;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.entity.User; import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
@ -24,7 +25,7 @@ public abstract class AbstractImageModelService<T> {
protected Proxy proxy; protected Proxy proxy;
protected String modelName; protected AiModel aiModel;
protected T setting; protected T setting;
@ -39,8 +40,8 @@ public abstract class AbstractImageModelService<T> {
protected ImageModel imageModel; protected ImageModel imageModel;
public AbstractImageModelService(String modelName, String settingName, Class<T> clazz) { public AbstractImageModelService(AiModel aiModel, String settingName, Class<T> clazz) {
this.modelName = modelName; this.aiModel = aiModel;
String st = LocalCache.CONFIGS.get(settingName); String st = LocalCache.CONFIGS.get(settingName);
setting = JsonUtil.fromJson(st, clazz); setting = JsonUtil.fromJson(st, clazz);
} }
@ -88,4 +89,8 @@ public abstract class AbstractImageModelService<T> {
public abstract List<String> editImage(User user, AiImage aiImage); public abstract List<String> editImage(User user, AiImage aiImage);
public abstract List<String> createImageVariation(User user, AiImage aiImage); public abstract List<String> createImageVariation(User user, AiImage aiImage);
public AiModel getAiModel() {
return aiModel;
}
} }

View File

@ -1,5 +1,6 @@
package com.moyz.adi.common.interfaces; package com.moyz.adi.common.interfaces;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.service.RAGService; import com.moyz.adi.common.service.RAGService;
import com.moyz.adi.common.util.JsonUtil; import com.moyz.adi.common.util.JsonUtil;
@ -37,9 +38,9 @@ public abstract class AbstractLLMService<T> {
protected Proxy proxy; protected Proxy proxy;
protected String modelName; protected AiModel aiModel;
protected T setting; protected T modelPlatformSetting;
protected StreamingChatLanguageModel streamingChatLanguageModel; protected StreamingChatLanguageModel streamingChatLanguageModel;
protected ChatLanguageModel chatLanguageModel; protected ChatLanguageModel chatLanguageModel;
@ -50,10 +51,10 @@ public abstract class AbstractLLMService<T> {
private RAGService queryCompressingRagService; private RAGService queryCompressingRagService;
public AbstractLLMService(String modelName, String settingName, Class<T> clazz) { public AbstractLLMService(AiModel aiModel, String settingName, Class<T> clazz) {
this.modelName = modelName; this.aiModel = aiModel;
String st = LocalCache.CONFIGS.get(settingName); String st = LocalCache.CONFIGS.get(settingName);
setting = JsonUtil.fromJson(st, clazz); modelPlatformSetting = JsonUtil.fromJson(st, clazz);
} }
public AbstractLLMService setProxy(Proxy proxy) { public AbstractLLMService setProxy(Proxy proxy) {
@ -61,7 +62,7 @@ public abstract class AbstractLLMService<T> {
return this; return this;
} }
public AbstractLLMService setQueryCompressingRAGService(RAGService ragService){ public AbstractLLMService setQueryCompressingRAGService(RAGService ragService) {
queryCompressingRagService = ragService; queryCompressingRagService = ragService;
return this; return this;
} }
@ -201,4 +202,8 @@ public abstract class AbstractLLMService<T> {
return chatAssistant.chat(messageId, userMessage); return chatAssistant.chat(messageId, userMessage);
} }
} }
public AiModel getAiModel() {
return aiModel;
}
} }

View File

@ -0,0 +1,51 @@
package com.moyz.adi.common.service;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.mapper.AiModelMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import java.util.List;
import static com.moyz.adi.common.util.LocalCache.MODEL_ID_TO_OBJ;
import static com.moyz.adi.common.util.LocalCache.MODEL_NAME_TO_OBJ;
@Slf4j
@Service
public class AiModelService extends ServiceImpl<AiModelMapper, AiModel> {
@Scheduled(fixedDelay = 5 * 60 * 1000)
public void initAll() {
List<AiModel> all = ChainWrappers.lambdaQueryChain(baseMapper)
.eq(AiModel::getIsDeleted, false)
.list();
for (AiModel model : all) {
MODEL_NAME_TO_OBJ.put(model.getName(), model);
MODEL_ID_TO_OBJ.put(model.getId(), model);
}
}
public List<AiModel> listBy(String platform, String type) {
return ChainWrappers.lambdaQueryChain(baseMapper)
.eq(AiModel::getPlatform, platform)
.eq(AiModel::getType, type)
.eq(AiModel::getIsDeleted, false)
.list();
}
public AiModel getByName(String modelName) {
return ChainWrappers.lambdaQueryChain(baseMapper)
.eq(AiModel::getName, modelName)
.eq(AiModel::getIsDeleted, false)
.one();
}
public Long getIdByName(String modelName) {
AiModel aiModel = this.getByName(modelName);
return null == aiModel ? 0l : aiModel.getId();
}
}

View File

@ -61,6 +61,9 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
@Resource @Resource
private SSEEmitterHelper sseEmitterHelper; private SSEEmitterHelper sseEmitterHelper;
@Resource
private AiModelService aiModelService;
public SseEmitter sseAsk(AskReq askReq) { public SseEmitter sseAsk(AskReq askReq) {
SseEmitter sseEmitter = new SseEmitter(); SseEmitter sseEmitter = new SseEmitter();
@ -143,30 +146,6 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
if (Boolean.TRUE.equals(conversation.getUnderstandContextEnable())) { if (Boolean.TRUE.equals(conversation.getUnderstandContextEnable())) {
sseAskParams.setMessageId(askReq.getConversationUuid()); sseAskParams.setMessageId(askReq.getConversationUuid());
} }
// List<ConversationMessage> historyMsgList = this.lambdaQuery()
// .eq(ConversationMessage::getUserId, user.getId())
// .eq(ConversationMessage::getConversationUuid, askReq.getConversationUuid())
// .orderByDesc(ConversationMessage::getId)
// .last("limit " + user.getUnderstandContextMsgPairNum() * 2)
// .list();
// if (!historyMsgList.isEmpty()) {
// ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(1000, new OpenAiTokenizer(GPT_3_5_TURBO));
// historyMsgList.sort(Comparator.comparing(ConversationMessage::getId));
// for (ConversationMessage historyMsg : historyMsgList) {
// if (ChatMessageRoleEnum.USER.getValue().equals(historyMsg.getMessageRole())) {
// UserMessage userMessage = UserMessage.from(historyMsg.getRemark());
// chatMemory.add(userMessage);
// } else if (ChatMessageRoleEnum.SYSTEM.getValue().equals(historyMsg.getMessageRole())) {
// SystemMessage systemMessage = SystemMessage.from(historyMsg.getRemark());
// chatMemory.add(systemMessage);
// }else if (ChatMessageRoleEnum.ASSISTANT.getValue().equals(historyMsg.getMessageRole())) {
// AiMessage aiMessage = AiMessage.from(historyMsg.getRemark());
// chatMemory.add(aiMessage);
// }
// }
// sseAskParams.setChatMemory(chatMemory);
// }
// }
} }
sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, questionMeta, answerMeta) -> { sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, questionMeta, answerMeta) -> {
_this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta); _this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta);
@ -228,7 +207,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
aiAnswer.setTokens(answerMeta.getTokens()); aiAnswer.setTokens(answerMeta.getTokens());
aiAnswer.setParentMessageId(promptMsg.getId()); aiAnswer.setParentMessageId(promptMsg.getId());
aiAnswer.setSecretKeyType(secretKeyType); aiAnswer.setSecretKeyType(secretKeyType);
aiAnswer.setLanguageModelName(askReq.getModelName()); aiAnswer.setAiModelId(aiModelService.getIdByName(askReq.getModelName()));
baseMapper.insert(aiAnswer); baseMapper.insert(aiAnswer);
calcTodayCost(user, conversation, questionMeta, answerMeta); calcTodayCost(user, conversation, questionMeta, answerMeta);

View File

@ -5,7 +5,8 @@ import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.dto.ConvDto; import com.moyz.adi.common.dto.ConvDto;
import com.moyz.adi.common.dto.ConvEditReq; import com.moyz.adi.common.dto.ConvEditReq;
import com.moyz.adi.common.dto.ConvMsgListResp; import com.moyz.adi.common.dto.ConvMsgListResp;
import com.moyz.adi.common.dto.ConvMsgResp; import com.moyz.adi.common.dto.ConvMsgDto;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.entity.Conversation; import com.moyz.adi.common.entity.Conversation;
import com.moyz.adi.common.entity.ConversationMessage; import com.moyz.adi.common.entity.ConversationMessage;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
@ -19,9 +20,12 @@ import org.springframework.stereotype.Service;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
import static com.moyz.adi.common.enums.ErrorEnum.A_CONVERSATION_EXIST; import static com.moyz.adi.common.enums.ErrorEnum.A_CONVERSATION_EXIST;
import static com.moyz.adi.common.enums.ErrorEnum.A_CONVERSATION_NOT_EXIST; import static com.moyz.adi.common.enums.ErrorEnum.A_CONVERSATION_NOT_EXIST;
import static com.moyz.adi.common.util.LocalCache.MODEL_ID_TO_OBJ;
import static com.moyz.adi.common.util.LocalCache.MODEL_NAME_TO_OBJ;
@Slf4j @Slf4j
@Service @Service
@ -33,6 +37,9 @@ public class ConversationService extends ServiceImpl<ConversationMapper, Convers
@Resource @Resource
private ConversationMessageService conversationMessageService; private ConversationMessageService conversationMessageService;
@Resource
private AiModelService aiModelService;
public List<ConvDto> listByUser() { public List<ConvDto> listByUser() {
List<Conversation> list = this.lambdaQuery() List<Conversation> list = this.lambdaQuery()
.eq(Conversation::getUserId, ThreadContext.getCurrentUserId()) .eq(Conversation::getUserId, ThreadContext.getCurrentUserId())
@ -81,7 +88,7 @@ public class ConversationService extends ServiceImpl<ConversationMapper, Convers
return b; return b;
}).getUuid(); }).getUuid();
//Wrap question content //Wrap question content
List<ConvMsgResp> userMessages = MPPageUtil.convertToList(questions, ConvMsgResp.class); List<ConvMsgDto> userMessages = MPPageUtil.convertToList(questions, ConvMsgDto.class);
ConvMsgListResp result = new ConvMsgListResp(minUuid, userMessages); ConvMsgListResp result = new ConvMsgListResp(minUuid, userMessages);
//Wrap answer content //Wrap answer content
@ -95,9 +102,14 @@ public class ConversationService extends ServiceImpl<ConversationMapper, Convers
//Fill AI answer to the request of user //Fill AI answer to the request of user
result.getMsgList().forEach(item -> { result.getMsgList().forEach(item -> {
List<ConvMsgResp> children = MPPageUtil.convertToList(idToMessages.get(item.getId()), ConvMsgResp.class); List<ConvMsgDto> children = MPPageUtil.convertToList(idToMessages.get(item.getId()), ConvMsgDto.class);
if (children.size() > 1) { if (children.size() > 1) {
children = children.stream().sorted(Comparator.comparing(ConvMsgResp::getCreateTime).reversed()).collect(Collectors.toList()); children = children.stream().sorted(Comparator.comparing(ConvMsgDto::getCreateTime).reversed()).collect(Collectors.toList());
}
for (ConvMsgDto convMsgDto : children) {
AiModel aiModel = MODEL_ID_TO_OBJ.get(convMsgDto.getAiModelId());
convMsgDto.setAiModelPlatform(null == aiModel ? "" : aiModel.getPlatform());
} }
item.setChildren(children); item.setChildren(children);
}); });

View File

@ -1,6 +1,7 @@
package com.moyz.adi.common.service; package com.moyz.adi.common.service;
import com.moyz.adi.common.cosntant.AdiConstant; import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.interfaces.AbstractLLMService; import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.vo.DashScopeSetting; import com.moyz.adi.common.vo.DashScopeSetting;
@ -11,8 +12,6 @@ import dev.langchain4j.model.dashscope.QwenStreamingChatModel;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.net.Proxy;
import static com.moyz.adi.common.enums.ErrorEnum.B_LLM_SECRET_KEY_NOT_SET; import static com.moyz.adi.common.enums.ErrorEnum.B_LLM_SECRET_KEY_NOT_SET;
/** /**
@ -21,23 +20,23 @@ import static com.moyz.adi.common.enums.ErrorEnum.B_LLM_SECRET_KEY_NOT_SET;
@Slf4j @Slf4j
public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> { public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> {
public DashScopeLLMService(String modelName) { public DashScopeLLMService(AiModel aiModel) {
super(modelName, AdiConstant.SysConfigKey.DASHSCOPE_SETTING, DashScopeSetting.class); super(aiModel, AdiConstant.SysConfigKey.DASHSCOPE_SETTING, DashScopeSetting.class);
} }
@Override @Override
public boolean isEnabled() { public boolean isEnabled() {
return StringUtils.isNotBlank(setting.getApiKey()); return StringUtils.isNotBlank(modelPlatformSetting.getApiKey()) && aiModel.getIsEnable();
} }
@Override @Override
protected StreamingChatLanguageModel buildStreamingChatLLM() { protected StreamingChatLanguageModel buildStreamingChatLLM() {
if (StringUtils.isBlank(setting.getApiKey())) { if (StringUtils.isBlank(modelPlatformSetting.getApiKey())) {
throw new BaseException(B_LLM_SECRET_KEY_NOT_SET); throw new BaseException(B_LLM_SECRET_KEY_NOT_SET);
} }
return QwenStreamingChatModel.builder() return QwenStreamingChatModel.builder()
.apiKey(setting.getApiKey()) .apiKey(modelPlatformSetting.getApiKey())
.modelName(modelName) .modelName(aiModel.getName())
.build(); .build();
} }
@ -48,12 +47,12 @@ public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> {
@Override @Override
protected ChatLanguageModel buildChatLLM() { protected ChatLanguageModel buildChatLLM() {
if (StringUtils.isBlank(setting.getApiKey())) { if (StringUtils.isBlank(modelPlatformSetting.getApiKey())) {
throw new BaseException(B_LLM_SECRET_KEY_NOT_SET); throw new BaseException(B_LLM_SECRET_KEY_NOT_SET);
} }
return QwenChatModel.builder() return QwenChatModel.builder()
.apiKey(setting.getApiKey()) .apiKey(modelPlatformSetting.getApiKey())
.modelName(modelName) .modelName(aiModel.getName())
.build(); .build();
} }

View File

@ -1,19 +1,24 @@
package com.moyz.adi.common.service; package com.moyz.adi.common.service;
import com.moyz.adi.common.cosntant.AdiConstant; import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.helper.ImageModelContext; import com.moyz.adi.common.helper.ImageModelContext;
import com.moyz.adi.common.helper.LLMContext; import com.moyz.adi.common.helper.LLMContext;
import com.moyz.adi.common.interfaces.AbstractImageModelService;
import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.searchengine.GoogleSearchEngine; import com.moyz.adi.common.searchengine.GoogleSearchEngine;
import com.moyz.adi.common.searchengine.SearchEngineContext; import com.moyz.adi.common.searchengine.SearchEngineContext;
import dev.langchain4j.model.openai.OpenAiModelName;
import jakarta.annotation.PostConstruct; import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.Proxy; import java.net.Proxy;
import java.util.List;
import java.util.function.Function;
@Slf4j @Slf4j
@Service @Service
@ -28,6 +33,9 @@ public class Initializer {
@Value("${adi.proxy.http-port:0}") @Value("${adi.proxy.http-port:0}")
protected int proxyHttpPort; protected int proxyHttpPort;
@Resource
private AiModelService aiModelService;
@Resource @Resource
private SysConfigService sysConfigService; private SysConfigService sysConfigService;
@ -37,52 +45,53 @@ public class Initializer {
@PostConstruct @PostConstruct
public void init() { public void init() {
sysConfigService.reload(); sysConfigService.reload();
aiModelService.initAll();
Proxy proxy = null; Proxy proxy;
if (proxyEnable) { if (proxyEnable) {
proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort)); proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort));
} else {
proxy = null;
} }
//openai //openai
String[] openaiModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OPENAI_SETTING); initLLMService(AdiConstant.ModelPlatform.OPENAI, (model) -> new OpenAiLLMService(model).setProxy(proxy).setQueryCompressingRAGService(queryCompressingRagService));
if (openaiModels.length == 0) {
log.warn("openai service is disabled");
}
for (String model : openaiModels) {
LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy).setQueryCompressingRAGService(queryCompressingRagService));
}
//dashscope //dashscope
String[] dashscopeModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.DASHSCOPE_SETTING); initLLMService(AdiConstant.ModelPlatform.DASHSCOPE, (model) -> new DashScopeLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
if (dashscopeModels.length == 0) {
log.warn("dashscope service is disabled");
}
for (String model : dashscopeModels) {
LLMContext.addLLMService(model, new DashScopeLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
}
//qianfan //qianfan
String[] qianfanModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.QIANFAN_SETTING); initLLMService(AdiConstant.ModelPlatform.QIANFAN, (model) -> new QianFanLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
if (qianfanModels.length == 0) {
log.warn("qianfan service is disabled");
}
for (String model : qianfanModels) {
LLMContext.addLLMService(model, new QianFanLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
}
//ollama //ollama
String[] ollamaModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OLLAMA_SETTING); initLLMService(AdiConstant.ModelPlatform.OLLAMA, (model) -> new OllamaLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
if (ollamaModels.length == 0) {
log.warn("ollama service is disabled");
}
for (String model : ollamaModels) {
LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model).setQueryCompressingRAGService(queryCompressingRagService));
}
ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2).setProxy(proxy)); //openai image model
initImageModelService(AdiConstant.ModelPlatform.OPENAI, (model) -> new OpenAiImageModelService(model).setProxy(proxy));
//search engine //search engine
SearchEngineContext.addEngine(AdiConstant.SearchEngineName.GOOGLE, new GoogleSearchEngine().setProxy(proxy)); SearchEngineContext.addEngine(AdiConstant.SearchEngineName.GOOGLE, new GoogleSearchEngine().setProxy(proxy));
} }
private void initLLMService(String platform, Function<AiModel, AbstractLLMService> function) {
List<AiModel> models = aiModelService.listBy(platform, AdiConstant.ModelType.TEXT);
if (CollectionUtils.isEmpty(models)) {
log.warn("{} service is disabled", platform);
}
for (AiModel model : models) {
LLMContext.addLLMService(function.apply(model));
}
}
private void initImageModelService(String platform, Function<AiModel, AbstractImageModelService> function) {
List<AiModel> models = aiModelService.listBy(platform, AdiConstant.ModelType.IMAGE);
if (CollectionUtils.isEmpty(models)) {
log.warn("{} service is disabled", platform);
}
for (AiModel model : models) {
ImageModelContext.addImageModelService(function.apply(model));
}
}
} }

View File

@ -5,11 +5,15 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers; import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.moyz.adi.common.base.ThreadContext; import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.dto.KbQaRecordDto;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.entity.KnowledgeBase; import com.moyz.adi.common.entity.KnowledgeBase;
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord; import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
import com.moyz.adi.common.entity.User; import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.mapper.KnowledgeBaseQaRecordMapper; import com.moyz.adi.common.mapper.KnowledgeBaseQaRecordMapper;
import com.moyz.adi.common.util.MPPageUtil;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -17,12 +21,16 @@ import org.springframework.stereotype.Service;
import java.util.UUID; import java.util.UUID;
import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND; import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND;
import static com.moyz.adi.common.util.LocalCache.MODEL_ID_TO_OBJ;
@Slf4j @Slf4j
@Service @Service
public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRecordMapper, KnowledgeBaseQaRecord> { public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRecordMapper, KnowledgeBaseQaRecord> {
public Page<KnowledgeBaseQaRecord> search(String kbUuid, String keyword, Integer currentPage, Integer pageSize) { @Resource
private AiModelService aiModelService;
public Page<KbQaRecordDto> search(String kbUuid, String keyword, Integer currentPage, Integer pageSize) {
LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>(); LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>();
wrapper.eq(KnowledgeBaseQaRecord::getKbUuid, kbUuid); wrapper.eq(KnowledgeBaseQaRecord::getKbUuid, kbUuid);
wrapper.eq(KnowledgeBaseQaRecord::getIsDeleted, false); wrapper.eq(KnowledgeBaseQaRecord::getIsDeleted, false);
@ -33,7 +41,15 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
wrapper.like(KnowledgeBaseQaRecord::getQuestion, keyword); wrapper.like(KnowledgeBaseQaRecord::getQuestion, keyword);
} }
wrapper.orderByDesc(KnowledgeBaseQaRecord::getUpdateTime); wrapper.orderByDesc(KnowledgeBaseQaRecord::getUpdateTime);
return baseMapper.selectPage(new Page<>(currentPage, pageSize), wrapper); Page<KnowledgeBaseQaRecord> page = baseMapper.selectPage(new Page<>(currentPage, pageSize), wrapper);
Page<KbQaRecordDto> result = new Page<>();
MPPageUtil.convertToPage(page, result, KbQaRecordDto.class, (t1, t2) -> {
AiModel aiModel = MODEL_ID_TO_OBJ.get(t1.getAiModelId());
t2.setAiModelPlatform(null == aiModel ? "" : aiModel.getPlatform());
return t2;
});
return result;
} }
/** /**
@ -45,9 +61,10 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
* @param promptTokens 提示词消耗的token * @param promptTokens 提示词消耗的token
* @param answer 答案 * @param answer 答案
* @param answerTokens 答案消耗的token * @param answerTokens 答案消耗的token
* @param modelName ai model name
* @return * @return
*/ */
public KnowledgeBaseQaRecord createNewRecord(User user, KnowledgeBase knowledgeBase, String question, String prompt, int promptTokens, String answer, int answerTokens) { public KnowledgeBaseQaRecord createNewRecord(User user, KnowledgeBase knowledgeBase, String question, String prompt, int promptTokens, String answer, int answerTokens, String modelName) {
String uuid = UUID.randomUUID().toString().replace("-", ""); String uuid = UUID.randomUUID().toString().replace("-", "");
KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord(); KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord();
newObj.setKbId(knowledgeBase.getId()); newObj.setKbId(knowledgeBase.getId());
@ -59,6 +76,7 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
newObj.setPromptTokens(promptTokens); newObj.setPromptTokens(promptTokens);
newObj.setAnswer(answer); newObj.setAnswer(answer);
newObj.setAnswerTokens(answerTokens); newObj.setAnswerTokens(answerTokens);
newObj.setAiModelId(aiModelService.getIdByName(modelName));
baseMapper.insert(newObj); baseMapper.insert(newObj);
LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>(); LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>();

View File

@ -244,7 +244,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
int inputTokenCount = ar.tokenUsage().inputTokenCount(); int inputTokenCount = ar.tokenUsage().inputTokenCount();
int outputTokenCount = ar.tokenUsage().outputTokenCount(); int outputTokenCount = ar.tokenUsage().outputTokenCount();
userDayCostService.appendCostToUser(ThreadContext.getCurrentUser(), inputTokenCount + outputTokenCount); userDayCostService.appendCostToUser(ThreadContext.getCurrentUser(), inputTokenCount + outputTokenCount);
return knowledgeBaseQaRecordService.createNewRecord(ThreadContext.getCurrentUser(), knowledgeBase, question, responsePair.getLeft(), inputTokenCount, ar.content().text(), outputTokenCount); return knowledgeBaseQaRecordService.createNewRecord(ThreadContext.getCurrentUser(), knowledgeBase, question, responsePair.getLeft(), inputTokenCount, ar.content().text(), outputTokenCount, modelName);
} }
public SseEmitter sseAsk(String kbUuid, QAReq req) { public SseEmitter sseAsk(String kbUuid, QAReq req) {
@ -328,7 +328,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
sseAskParams.setUserMessage(promptText); sseAskParams.setUserMessage(promptText);
sseAskParams.setModelName(req.getModelName()); sseAskParams.setModelName(req.getModelName());
sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, promptMeta, answerMeta) -> { sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, promptMeta, answerMeta) -> {
knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), promptText, promptMeta.getTokens(), response, answerMeta.getTokens()); knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), promptText, promptMeta.getTokens(), response, answerMeta.getTokens(), req.getModelName());
userDayCostService.appendCostToUser(user, promptMeta.getTokens() + answerMeta.getTokens()); userDayCostService.appendCostToUser(user, promptMeta.getTokens() + answerMeta.getTokens());
}); });
} }

View File

@ -1,5 +1,6 @@
package com.moyz.adi.common.service; package com.moyz.adi.common.service;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.interfaces.AbstractLLMService; import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.vo.OllamaSetting; import com.moyz.adi.common.vo.OllamaSetting;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
@ -12,20 +13,20 @@ import static com.moyz.adi.common.cosntant.AdiConstant.SysConfigKey.OLLAMA_SETTI
public class OllamaLLMService extends AbstractLLMService<OllamaSetting> { public class OllamaLLMService extends AbstractLLMService<OllamaSetting> {
public OllamaLLMService(String modelName) { public OllamaLLMService(AiModel aiModel) {
super(modelName, OLLAMA_SETTING, OllamaSetting.class); super(aiModel, OLLAMA_SETTING, OllamaSetting.class);
} }
@Override @Override
public boolean isEnabled() { public boolean isEnabled() {
return StringUtils.isNotBlank(setting.getBaseUrl()); return StringUtils.isNotBlank(modelPlatformSetting.getBaseUrl()) && aiModel.getIsEnable();
} }
@Override @Override
protected ChatLanguageModel buildChatLLM() { protected ChatLanguageModel buildChatLLM() {
return OllamaChatModel.builder() return OllamaChatModel.builder()
.baseUrl(setting.getBaseUrl()) .baseUrl(modelPlatformSetting.getBaseUrl())
.modelName(modelName) .modelName(aiModel.getName())
.temperature(0.0) .temperature(0.0)
.build(); .build();
} }
@ -33,8 +34,8 @@ public class OllamaLLMService extends AbstractLLMService<OllamaSetting> {
@Override @Override
protected StreamingChatLanguageModel buildStreamingChatLLM() { protected StreamingChatLanguageModel buildStreamingChatLLM() {
return OllamaStreamingChatModel.builder() return OllamaStreamingChatModel.builder()
.baseUrl(setting.getBaseUrl()) .baseUrl(modelPlatformSetting.getBaseUrl())
.modelName(modelName) .modelName(aiModel.getName())
.build(); .build();
} }

View File

@ -3,6 +3,7 @@ package com.moyz.adi.common.service;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.moyz.adi.common.cosntant.AdiConstant; import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.entity.AiImage; import com.moyz.adi.common.entity.AiImage;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.entity.User; import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
@ -45,13 +46,13 @@ public class OpenAiImageModelService extends AbstractImageModelService<OpenAiSet
@Resource @Resource
private ObjectMapper objectMapper; private ObjectMapper objectMapper;
public OpenAiImageModelService(String modelName) { public OpenAiImageModelService(AiModel aiModel) {
super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class); super(aiModel, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class);
} }
@Override @Override
public boolean isEnabled() { public boolean isEnabled() {
return StringUtils.isNotBlank(setting.getSecretKey()); return StringUtils.isNotBlank(setting.getSecretKey()) && aiModel.getIsEnable();
} }
@Override @Override
@ -60,7 +61,7 @@ public class OpenAiImageModelService extends AbstractImageModelService<OpenAiSet
throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET); throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET);
} }
OpenAiImageModel.OpenAiImageModelBuilder builder = OpenAiImageModel.builder() OpenAiImageModel.OpenAiImageModelBuilder builder = OpenAiImageModel.builder()
.modelName(modelName) .modelName(aiModel.getName())
.apiKey(setting.getSecretKey()) .apiKey(setting.getSecretKey())
.user(user.getUuid()) .user(user.getUuid())
.responseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL) .responseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL)

View File

@ -1,6 +1,7 @@
package com.moyz.adi.common.service; package com.moyz.adi.common.service;
import com.moyz.adi.common.cosntant.AdiConstant; import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.interfaces.AbstractLLMService; import com.moyz.adi.common.interfaces.AbstractLLMService;
@ -27,24 +28,24 @@ import java.time.temporal.ChronoUnit;
@Accessors(chain = true) @Accessors(chain = true)
public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> { public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
public OpenAiLLMService(String modelName) { public OpenAiLLMService(AiModel model) {
super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class); super(model, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class);
} }
@Override @Override
public boolean isEnabled() { public boolean isEnabled() {
return StringUtils.isNotBlank(setting.getSecretKey()); return StringUtils.isNotBlank(modelPlatformSetting.getSecretKey()) && aiModel.getIsEnable();
} }
@Override @Override
protected StreamingChatLanguageModel buildStreamingChatLLM() { protected StreamingChatLanguageModel buildStreamingChatLLM() {
if (StringUtils.isBlank(setting.getSecretKey())) { if (StringUtils.isBlank(modelPlatformSetting.getSecretKey())) {
throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET); throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET);
} }
OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder builder = OpenAiStreamingChatModel OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder builder = OpenAiStreamingChatModel
.builder() .builder()
.modelName(modelName) .modelName(aiModel.getName())
.apiKey(setting.getSecretKey()) .apiKey(modelPlatformSetting.getSecretKey())
.timeout(Duration.of(60, ChronoUnit.SECONDS)); .timeout(Duration.of(60, ChronoUnit.SECONDS));
if (null != proxy) { if (null != proxy) {
builder.proxy(proxy); builder.proxy(proxy);
@ -54,8 +55,8 @@ public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
@Override @Override
protected String parseError(Object error) { protected String parseError(Object error) {
if(error instanceof OpenAiHttpException){ if (error instanceof OpenAiHttpException) {
OpenAiHttpException openAiHttpException = (OpenAiHttpException)error; OpenAiHttpException openAiHttpException = (OpenAiHttpException) error;
OpenAiError openAiError = JsonUtil.fromJson(openAiHttpException.getMessage(), OpenAiError.class); OpenAiError openAiError = JsonUtil.fromJson(openAiHttpException.getMessage(), OpenAiError.class);
return openAiError.getError().getMessage(); return openAiError.getError().getMessage();
} }
@ -64,10 +65,10 @@ public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
@Override @Override
protected ChatLanguageModel buildChatLLM() { protected ChatLanguageModel buildChatLLM() {
if (StringUtils.isBlank(setting.getSecretKey())) { if (StringUtils.isBlank(modelPlatformSetting.getSecretKey())) {
throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET); throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET);
} }
OpenAiChatModel.OpenAiChatModelBuilder builder = OpenAiChatModel.builder().apiKey(setting.getSecretKey()); OpenAiChatModel.OpenAiChatModelBuilder builder = OpenAiChatModel.builder().apiKey(modelPlatformSetting.getSecretKey());
if (null != proxy) { if (null != proxy) {
builder.proxy(proxy); builder.proxy(proxy);
} }

View File

@ -1,6 +1,7 @@
package com.moyz.adi.common.service; package com.moyz.adi.common.service;
import com.moyz.adi.common.cosntant.AdiConstant; import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.interfaces.AbstractLLMService; import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.vo.QianFanSetting; import com.moyz.adi.common.vo.QianFanSetting;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
@ -18,35 +19,35 @@ import org.apache.commons.lang3.StringUtils;
@Accessors(chain = true) @Accessors(chain = true)
public class QianFanLLMService extends AbstractLLMService<QianFanSetting> { public class QianFanLLMService extends AbstractLLMService<QianFanSetting> {
public QianFanLLMService(String modelName) { public QianFanLLMService(AiModel aiModel) {
super(modelName, AdiConstant.SysConfigKey.QIANFAN_SETTING, QianFanSetting.class); super(aiModel, AdiConstant.SysConfigKey.QIANFAN_SETTING, QianFanSetting.class);
} }
@Override @Override
public boolean isEnabled() { public boolean isEnabled() {
return StringUtils.isNoneBlank(setting.getApiKey(), setting.getSecretKey()); return StringUtils.isNoneBlank(modelPlatformSetting.getApiKey(), modelPlatformSetting.getSecretKey()) && aiModel.getIsEnable();
} }
@Override @Override
protected ChatLanguageModel buildChatLLM() { protected ChatLanguageModel buildChatLLM() {
return QianfanChatModel.builder() return QianfanChatModel.builder()
.modelName(modelName) .modelName(aiModel.getName())
.temperature(0.7) .temperature(0.7)
.topP(1.0) .topP(1.0)
.maxRetries(1) .maxRetries(1)
.apiKey(setting.getApiKey()) .apiKey(modelPlatformSetting.getApiKey())
.secretKey(setting.getSecretKey()) .secretKey(modelPlatformSetting.getSecretKey())
.build(); .build();
} }
@Override @Override
protected StreamingChatLanguageModel buildStreamingChatLLM() { protected StreamingChatLanguageModel buildStreamingChatLLM() {
return QianfanStreamingChatModel.builder() return QianfanStreamingChatModel.builder()
.modelName(modelName) .modelName(aiModel.getName())
.temperature(0.7) .temperature(0.7)
.topP(1.0) .topP(1.0)
.apiKey(setting.getApiKey()) .apiKey(modelPlatformSetting.getApiKey())
.secretKey(setting.getSecretKey()) .secretKey(modelPlatformSetting.getSecretKey())
.build(); .build();
} }

View File

@ -1,14 +1,19 @@
package com.moyz.adi.common.util; package com.moyz.adi.common.util;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.vo.RequestRateLimit; import com.moyz.adi.common.vo.RequestRateLimit;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class LocalCache { public class LocalCache {
public static final Map<String, String> CONFIGS = new HashMap<>(); public static final Map<String, String> CONFIGS = new ConcurrentHashMap<>();
public static RequestRateLimit TEXT_RATE_LIMIT_CONFIG; public static RequestRateLimit TEXT_RATE_LIMIT_CONFIG;
public static RequestRateLimit IMAGE_RATE_LIMIT_CONFIG; public static RequestRateLimit IMAGE_RATE_LIMIT_CONFIG;
public static Map<String, AiModel> MODEL_NAME_TO_OBJ = new ConcurrentHashMap<>();
public static Map<Long, AiModel> MODEL_ID_TO_OBJ = new ConcurrentHashMap<>();
} }

View File

@ -49,26 +49,39 @@ COMMENT ON COLUMN public.adi_ai_image.is_deleted IS 'Flag indicating whether the
CREATE TABLE public.adi_ai_model CREATE TABLE public.adi_ai_model
( (
id bigserial primary key, id bigserial primary key,
name character varying(45) DEFAULT ''::character varying NOT NULL, name varchar(45) default '' not null,
remark character varying(1000), type varchar(45) default 'llm' not null,
model_status smallint DEFAULT '1'::smallint NOT NULL, remark varchar(1000) default '',
create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL, platform varchar(45) default '' not null,
update_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL, max_tokens int default 0 not null,
is_deleted boolean DEFAULT false NOT NULL, is_enable boolean default false NOT NULL,
CONSTRAINT adi_ai_model_model_status_check CHECK ((model_status = ANY (ARRAY [1, 2]))) create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
update_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
is_deleted boolean DEFAULT false NOT NULL
); );
COMMENT ON TABLE public.adi_ai_model IS 'ai模型'; COMMENT ON TABLE public.adi_ai_model IS 'ai模型';
COMMENT ON COLUMN public.adi_ai_model.type IS 'The type of the AI model,eg: text,image,embedding,rerank';
COMMENT ON COLUMN public.adi_ai_model.name IS 'The name of the AI model'; COMMENT ON COLUMN public.adi_ai_model.name IS 'The name of the AI model';
COMMENT ON COLUMN public.adi_ai_model.remark IS 'Additional remarks about the AI model'; COMMENT ON COLUMN public.adi_ai_model.remark IS 'Additional remarks about the AI model';
COMMENT ON COLUMN public.adi_ai_model.model_status IS '1: Normal usage, 2: Not available'; COMMENT ON COLUMN public.adi_ai_model.platform IS 'eg: openai,dashscope,qianfan,ollama';
COMMENT ON COLUMN public.adi_ai_model.max_tokens IS 'The maximum number of tokens that can be generated';
COMMENT ON COLUMN public.adi_ai_model.is_enable IS '1: Normal usage, 0: Not available';
COMMENT ON COLUMN public.adi_ai_model.create_time IS 'Timestamp of record creation'; COMMENT ON COLUMN public.adi_ai_model.create_time IS 'Timestamp of record creation';
COMMENT ON COLUMN public.adi_ai_model.update_time IS 'Timestamp of record last update, automatically updated on each update'; COMMENT ON COLUMN public.adi_ai_model.update_time IS 'Timestamp of record last update, automatically updated on each update';
INSERT INTO adi_ai_model (name, type, platform, max_tokens, is_enable)
VALUES ('gpt-3.5-turbo', 'text', 'openai', 2048, false);
INSERT INTO adi_ai_model (name, type, platform, is_enable)
VALUES ('dall-e-2', 'image', 'openai', false);
INSERT INTO adi_ai_model (name, type, platform, is_enable)
VALUES ('qwen-turbo', 'text', 'dashscope', false);
INSERT INTO adi_ai_model (name, type, platform, is_enable)
VALUES ('ernie-3.5-8k-0205', 'text', 'qianfan', false);
INSERT INTO adi_ai_model (name, type, platform, is_enable)
VALUES ('tinydolphin', 'text', 'ollama', false);
CREATE TABLE public.adi_conversation CREATE TABLE public.adi_conversation
( (
id bigserial primary key, id bigserial primary key,
@ -103,8 +116,8 @@ CREATE TABLE public.adi_conversation_message
uuid character varying(32) DEFAULT ''::character varying NOT NULL, uuid character varying(32) DEFAULT ''::character varying NOT NULL,
message_role integer DEFAULT 1 NOT NULL, message_role integer DEFAULT 1 NOT NULL,
tokens integer DEFAULT 0 NOT NULL, tokens integer DEFAULT 0 NOT NULL,
language_model_name character varying(32) DEFAULT ''::character varying NOT NULL,
user_id bigint DEFAULT '0'::bigint NOT NULL, user_id bigint DEFAULT '0'::bigint NOT NULL,
ai_model_id bigint default 0 not null,
secret_key_type integer DEFAULT 1 NOT NULL, secret_key_type integer DEFAULT 1 NOT NULL,
understand_context_msg_pair_num integer DEFAULT 0 NOT NULL, understand_context_msg_pair_num integer DEFAULT 0 NOT NULL,
create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL, create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
@ -131,6 +144,8 @@ COMMENT ON COLUMN public.adi_conversation_message.language_model_name IS 'LLM na
COMMENT ON COLUMN public.adi_conversation_message.user_id IS '用户ID'; COMMENT ON COLUMN public.adi_conversation_message.user_id IS '用户ID';
COMMENT ON COLUMN public.adi_conversation_message.ai_model_id IS 'adi_ai_model id';
COMMENT ON COLUMN public.adi_conversation_message.secret_key_type IS '加密密钥类型'; COMMENT ON COLUMN public.adi_conversation_message.secret_key_type IS '加密密钥类型';
COMMENT ON COLUMN public.adi_conversation_message.understand_context_msg_pair_num IS '上下文消息对数量'; COMMENT ON COLUMN public.adi_conversation_message.understand_context_msg_pair_num IS '上下文消息对数量';
@ -381,6 +396,12 @@ CREATE TRIGGER trigger_user_day_cost_update_time
FOR EACH ROW FOR EACH ROW
EXECUTE PROCEDURE update_modified_column(); EXECUTE PROCEDURE update_modified_column();
create trigger trigger_ai_model
before update
on adi_ai_model
for each row
execute procedure update_modified_column();
INSERT INTO adi_sys_config (name, value) INSERT INTO adi_sys_config (name, value)
VALUES ('openai_setting', '{"secret_key":"","models":[]}'); VALUES ('openai_setting', '{"secret_key":"","models":[]}');
INSERT INTO adi_sys_config (name, value) INSERT INTO adi_sys_config (name, value)
@ -549,6 +570,7 @@ create table adi_knowledge_base_qa_record
answer_tokens integer DEFAULT 0 NOT NULL, answer_tokens integer DEFAULT 0 NOT NULL,
source_file_ids varchar(500) default ''::character varying not null, source_file_ids varchar(500) default ''::character varying not null,
user_id bigint default '0' NOT NULL, user_id bigint default '0' NOT NULL,
ai_model_id bigint default 0 not null,
create_time timestamp default CURRENT_TIMESTAMP not null, create_time timestamp default CURRENT_TIMESTAMP not null,
update_time timestamp default CURRENT_TIMESTAMP not null, update_time timestamp default CURRENT_TIMESTAMP not null,
is_deleted boolean default false not null is_deleted boolean default false not null