diff --git a/README.md b/README.md index ad06650..a50e42c 100644 --- a/README.md +++ b/README.md @@ -10,10 +10,8 @@ * 图片生成(文生图、修图、图生图) * 提示词 * 额度控制 -* 自定义openai secret_key * 基于大模型的知识库(RAG) - -![1691585301627](image/README/1691585301627.png "登录注册") +* 多模型随意切换 **AI聊天:** ![1691583184761](image/README/1691583184761.png) @@ -31,14 +29,18 @@ ![kb03](image/README/kb03.png) -体验网址:[http://www.aideepin.com](http://www.aideepin.com/) +### 体验网址 +[http://www.aideepin.com](http://www.aideepin.com/) -接入的模型:ChatGPT 3.5,DALL-E 2 - -该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web) +### 接入的模型: +* ChatGPT 3.5 +* 通义千问 +* DALL-E 2 ### 技术 +该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web) + 后端: jdk17 @@ -61,10 +63,11 @@ vue3+typescript+pnpm * 创建数据库aideepin * 执行docs/create.sql -* 填充openai的secret\_key +* 填充openai的secretKey 或者 灵积模型的apiKey ```plaintext -update adi_sys_config set value = 'my_chatgpt_secret_key' where name = 'secret_key' +update adi_sys_config set value = '{"secret_key":"my_openai_secret_key"}' where name = 'openai_setting'; +update adi_sys_config set value = '{"api_key":"my_dashcope_api_key"}' where name = 'dashscope_setting'; ``` * 修改配置文件 diff --git a/adi-bootstrap/src/main/resources/application-dev.yml b/adi-bootstrap/src/main/resources/application-dev.yml index ec3063a..27199e5 100644 --- a/adi-bootstrap/src/main/resources/application-dev.yml +++ b/adi-bootstrap/src/main/resources/application-dev.yml @@ -25,17 +25,14 @@ logging: file: path: D:/data/logs -openai: +adi: + frontend-url: http://localhost:1002 + backend-url: http://localhost:1002/api proxy: enable: true host: 127.0.0.1 http-port: 1087 -adi: - frontend-url: http://localhost:1002 - backend-url: http://localhost:1002/api - - local: files: D:/data/files/ images: D:/data/images/ diff --git a/adi-bootstrap/src/main/resources/application.yml b/adi-bootstrap/src/main/resources/application.yml index 7a65237..4174f9f 100644 --- a/adi-bootstrap/src/main/resources/application.yml +++ b/adi-bootstrap/src/main/resources/application.yml @@ -53,6 +53,10 @@ logging: adi: frontend-url: http://www.aideepin.com backend-url: http://www.aideepin.com/api + proxy: + enable: false + host: 127.0.0.1 + http-port: 1087 local: files: /data/files/ diff --git a/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseItemController.java b/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseItemController.java index 560d763..9307097 100644 --- a/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseItemController.java +++ b/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseItemController.java @@ -39,6 +39,12 @@ public class KnowledgeBaseItemController { .one(); } + /** + * 知识点向量化 + * + * @param uuid 知识点uuid + * @return + */ @PostMapping("/embedding/{uuid}") public boolean embedding(@PathVariable String uuid) { return knowledgeBaseItemService.checkAndEmbedding(uuid); diff --git a/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseQAController.java b/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseQAController.java index bf30ac5..82c956b 100644 --- a/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseQAController.java +++ b/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseQAController.java @@ -25,7 +25,7 @@ public class KnowledgeBaseQAController { @PostMapping("/ask/{kbUuid}") public KnowledgeBaseQaRecord ask(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) { - return knowledgeBaseService.answerAndRecord(kbUuid, req.getQuestion()); + return knowledgeBaseService.answerAndRecord(kbUuid, req.getQuestion(), req.getModelName()); } @GetMapping("/record/search") diff --git a/adi-chat/src/main/java/com/moyz/adi/chat/controller/ModelController.java b/adi-chat/src/main/java/com/moyz/adi/chat/controller/ModelController.java new file mode 100644 index 0000000..31c41a1 --- /dev/null +++ b/adi-chat/src/main/java/com/moyz/adi/chat/controller/ModelController.java @@ -0,0 +1,29 @@ +package com.moyz.adi.chat.controller; + +import com.moyz.adi.common.helper.ImageModelContext; +import com.moyz.adi.common.helper.LLMContext; +import com.moyz.adi.common.vo.ImageModelInfo; +import com.moyz.adi.common.vo.LLMModelInfo; +import io.swagger.v3.oas.annotations.Operation; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.util.List; +import java.util.stream.Collectors; + +@RestController +@RequestMapping("/model") +public class ModelController { + @Operation(summary = "支持的大语言模型列表") + @GetMapping(value = "/llms") + public List llms() { + return LLMContext.NAME_TO_MODEL.values().stream().collect(Collectors.toList()); + } + + @Operation(summary = "支持的图片模型列表") + @GetMapping(value = "/imageModels") + public List imageModels() { + return ImageModelContext.NAME_TO_MODEL.values().stream().collect(Collectors.toList()); + } +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/cosntant/AdiConstant.java b/adi-common/src/main/java/com/moyz/adi/common/cosntant/AdiConstant.java index ba830d2..caaf127 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/cosntant/AdiConstant.java +++ b/adi-common/src/main/java/com/moyz/adi/common/cosntant/AdiConstant.java @@ -65,7 +65,8 @@ public class AdiConstant { } public static class SysConfigKey { - public static final String SECRET_KEY = "secret_key"; + public static final String OPENAI_SETTING = "openai_setting"; + public static final String DASHSCOPE_SETTING = "dashscope_setting"; public static final String REQUEST_TEXT_RATE_LIMIT = "request_text_rate_limit"; public static final String REQUEST_IMAGE_RATE_LIMIT = "request_image_rate_limit"; public static final String CONVERSATION_MAX_NUM = "conversation_max_num"; diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/AskReq.java b/adi-common/src/main/java/com/moyz/adi/common/dto/AskReq.java index 445de3e..c7cde9a 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/dto/AskReq.java +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/AskReq.java @@ -21,4 +21,6 @@ public class AskReq { * If not empty, it means will request AI with the exist prompt, param {@code prompt} is ignored */ private String regenerateQuestionUuid; + + private String modelName; } diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/EditImageReq.java b/adi-common/src/main/java/com/moyz/adi/common/dto/EditImageReq.java index 646ad3b..a540afe 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/dto/EditImageReq.java +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/EditImageReq.java @@ -19,4 +19,6 @@ public class EditImageReq { @Min(1) @Max(10) private int number; + + private String modelName; } diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/GenerateImageReq.java b/adi-common/src/main/java/com/moyz/adi/common/dto/GenerateImageReq.java index ab9414e..5baa320 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/dto/GenerateImageReq.java +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/GenerateImageReq.java @@ -14,4 +14,6 @@ public class GenerateImageReq { @Min(1) @Max(10) private int number; + + private String modelName; } diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/QAReq.java b/adi-common/src/main/java/com/moyz/adi/common/dto/QAReq.java index 15e7bfb..55f9f15 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/dto/QAReq.java +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/QAReq.java @@ -10,4 +10,6 @@ public class QAReq { @NotBlank private String question; + + private String modelName; } diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/VariationImageReq.java b/adi-common/src/main/java/com/moyz/adi/common/dto/VariationImageReq.java index 9893df3..1a4c3dc 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/dto/VariationImageReq.java +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/VariationImageReq.java @@ -15,4 +15,6 @@ public class VariationImageReq { @Min(1) @Max(10) private int number; + + private String modelName; } diff --git a/adi-common/src/main/java/com/moyz/adi/common/entity/KnowledgeBaseEmbedding.java b/adi-common/src/main/java/com/moyz/adi/common/entity/KnowledgeBaseEmbedding.java index 59cc4c2..caece7e 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/entity/KnowledgeBaseEmbedding.java +++ b/adi-common/src/main/java/com/moyz/adi/common/entity/KnowledgeBaseEmbedding.java @@ -1,6 +1,8 @@ package com.moyz.adi.common.entity; +import com.baomidou.mybatisplus.annotation.IdType; import com.baomidou.mybatisplus.annotation.TableField; +import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; import com.pgvector.PGvector; import io.swagger.v3.oas.annotations.media.Schema; @@ -9,10 +11,10 @@ import lombok.Data; @Data @TableName("adi_knowledge_base_embedding") @Schema(title = "知识库-嵌入实体", description = "知识库嵌入表") -public class KnowledgeBaseEmbedding extends BaseEntity { +public class KnowledgeBaseEmbedding{ - @Schema(title = "embedding uuid") - @TableField("embedding") + @Schema(title = "embedding_id") + @TableId(value = "embedding_id", type = IdType.AUTO) private String embeddingId; @Schema(title = "embedding") diff --git a/adi-common/src/main/java/com/moyz/adi/common/enums/ErrorEnum.java b/adi-common/src/main/java/com/moyz/adi/common/enums/ErrorEnum.java index 4a9d1e7..0aaffb1 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/enums/ErrorEnum.java +++ b/adi-common/src/main/java/com/moyz/adi/common/enums/ErrorEnum.java @@ -14,7 +14,7 @@ public enum ErrorEnum { A_IMAGE_SIZE_ERROR("A0010", "图片尺寸不对"), A_FILE_NOT_EXIST("A0011", "文件不存在"), A_DRAWING("A0012", "作图还未完成"), - A_REGISTER_USER_EXIST("A0013", "账号已经存在,请使用账号密码登录"), + A_USER_EXIST("A0013", "账号已经存在,请使用账号密码登录"), A_FIND_PASSWORD_CODE_ERROR("A0014", "重置码已过期或不存在"), A_USER_WAIT_CONFIRM("A0015", "用户未激活"), A_USER_NOT_AUTH("A0016", "用户无权限"), @@ -29,7 +29,8 @@ public enum ErrorEnum { B_FIND_IMAGE_404("B0005", "无法找到图片"), B_DAILY_QUOTA_USED("B0006", "今天额度已经用完"), B_MONTHLY_QUOTA_USED("B0007", "当月额度已经用完"), - + B_LLM_NOT_SUPPORT("B0008", "LLM不支持该功能"), + B_LLM_SECRET_KEY_NOT_SET("B0009", "LLM的secret key没设置"), B_MESSAGE_NOT_FOUND("B0008", "消息不存在"); private String code; diff --git a/adi-common/src/main/java/com/moyz/adi/common/filter/TokenFilter.java b/adi-common/src/main/java/com/moyz/adi/common/filter/TokenFilter.java index 7bc59ff..f387b6d 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/filter/TokenFilter.java +++ b/adi-common/src/main/java/com/moyz/adi/common/filter/TokenFilter.java @@ -29,7 +29,8 @@ import static org.springframework.http.HttpHeaders.AUTHORIZATION; public class TokenFilter extends OncePerRequestFilter { public static final String[] EXCLUDE_API = { - "/auth/" + "/auth/", + "/model/" }; @Resource diff --git a/adi-common/src/main/java/com/moyz/adi/common/helper/ImageModelContext.java b/adi-common/src/main/java/com/moyz/adi/common/helper/ImageModelContext.java new file mode 100644 index 0000000..a26eb4a --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/helper/ImageModelContext.java @@ -0,0 +1,49 @@ +package com.moyz.adi.common.helper; + +import com.moyz.adi.common.interfaces.AbstractImageModelService; +import com.moyz.adi.common.vo.ImageModelInfo; +import lombok.extern.slf4j.Slf4j; + +import java.util.HashMap; +import java.util.Map; + +import static dev.langchain4j.model.openai.OpenAiModelName.DALL_E_2; + +/** + * image model service上下文类(策略模式) + */ +@Slf4j +public class ImageModelContext { + + /** + * AI图片模型 + */ + public static final Map NAME_TO_MODEL = new HashMap<>(); + + private AbstractImageModelService modelService; + + public ImageModelContext() { + modelService = NAME_TO_MODEL.get(DALL_E_2).getModelService(); + } + + public ImageModelContext(String modelName) { + if (null == NAME_TO_MODEL.get(modelName)) { + log.warn("︿︿︿ Can not find {}, use the default model DALL_E_2 ︿︿︿", modelName); + modelService = NAME_TO_MODEL.get(DALL_E_2).getModelService(); + } else { + modelService = NAME_TO_MODEL.get(modelName).getModelService(); + } + } + + public static void addImageModelService(String modelName, AbstractImageModelService modelService) { + ImageModelInfo imageModelInfo = new ImageModelInfo(); + imageModelInfo.setModelService(modelService); + imageModelInfo.setModelName(modelName); + imageModelInfo.setEnable(modelService.isEnabled()); + NAME_TO_MODEL.put(modelName, imageModelInfo); + } + + public AbstractImageModelService getModelService() { + return modelService; + } +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/helper/LLMContext.java b/adi-common/src/main/java/com/moyz/adi/common/helper/LLMContext.java new file mode 100644 index 0000000..c7b6ebf --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/helper/LLMContext.java @@ -0,0 +1,44 @@ +package com.moyz.adi.common.helper; + +import com.moyz.adi.common.interfaces.AbstractLLMService; +import com.moyz.adi.common.vo.LLMModelInfo; +import lombok.extern.slf4j.Slf4j; + +import java.util.HashMap; +import java.util.Map; + +import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; + +/** + * llmService上下文类(策略模式) + */ +@Slf4j +public class LLMContext { + public static final Map NAME_TO_MODEL = new HashMap<>(); + private AbstractLLMService llmService; + + public LLMContext() { + llmService = NAME_TO_MODEL.get(GPT_3_5_TURBO).getLlmService(); + } + + public LLMContext(String modelName) { + if (null == NAME_TO_MODEL.get(modelName)) { + log.warn("︿︿︿ Can not find {}, use the default model GPT_3_5_TURBO ︿︿︿", modelName); + llmService = NAME_TO_MODEL.get(GPT_3_5_TURBO).getLlmService(); + } else { + llmService = NAME_TO_MODEL.get(modelName).getLlmService(); + } + } + + public static void addLLMService(String modelName, AbstractLLMService llmService) { + LLMModelInfo llmModelInfo = new LLMModelInfo(); + llmModelInfo.setModelName(modelName); + llmModelInfo.setEnable(llmService.isEnabled()); + llmModelInfo.setLlmService(llmService); + NAME_TO_MODEL.put(modelName, llmModelInfo); + } + + public AbstractLLMService getLLMService() { + return llmService; + } +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/helper/OpenAiHelper.java b/adi-common/src/main/java/com/moyz/adi/common/helper/OpenAiHelper.java deleted file mode 100644 index f8bcbe8..0000000 --- a/adi-common/src/main/java/com/moyz/adi/common/helper/OpenAiHelper.java +++ /dev/null @@ -1,260 +0,0 @@ -package com.moyz.adi.common.helper; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.moyz.adi.common.base.ThreadContext; -import com.moyz.adi.common.entity.AiImage; -import com.moyz.adi.common.entity.User; -import com.moyz.adi.common.enums.ErrorEnum; -import com.moyz.adi.common.exception.BaseException; -import com.moyz.adi.common.interfaces.IChatAssistant; -import com.moyz.adi.common.service.FileService; -import com.moyz.adi.common.service.SysConfigService; -import com.moyz.adi.common.util.ImageUtil; -import com.moyz.adi.common.util.JsonUtil; -import com.moyz.adi.common.util.TriConsumer; -import com.moyz.adi.common.vo.AnswerMeta; -import com.moyz.adi.common.vo.ChatMeta; -import com.moyz.adi.common.vo.QuestionMeta; -import com.moyz.adi.common.vo.SseAskParams; -import com.theokanning.openai.OpenAiApi; -import com.theokanning.openai.image.CreateImageEditRequest; -import com.theokanning.openai.image.CreateImageVariationRequest; -import com.theokanning.openai.image.ImageResult; -import com.theokanning.openai.service.OpenAiService; -import dev.langchain4j.data.image.Image; -import dev.langchain4j.memory.ChatMemory; -import dev.langchain4j.model.image.ImageModel; -import dev.langchain4j.model.openai.OpenAiImageModel; -import dev.langchain4j.model.openai.OpenAiStreamingChatModel; -import dev.langchain4j.model.output.Response; -import dev.langchain4j.service.AiServices; -import dev.langchain4j.service.TokenStream; -import jakarta.annotation.Resource; -import lombok.extern.slf4j.Slf4j; -import okhttp3.OkHttpClient; -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Service; -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; -import retrofit2.Retrofit; - -import java.io.File; -import java.io.IOException; -import java.net.InetSocketAddress; -import java.net.Proxy; -import java.time.Duration; -import java.time.temporal.ChronoUnit; -import java.util.Collections; -import java.util.List; -import java.util.UUID; -import java.util.stream.Collectors; - -import static com.moyz.adi.common.cosntant.AdiConstant.OPENAI_CREATE_IMAGE_RESP_FORMATS_URL; -import static com.moyz.adi.common.cosntant.AdiConstant.OPENAI_CREATE_IMAGE_SIZES; -import static com.theokanning.openai.service.OpenAiService.defaultClient; -import static com.theokanning.openai.service.OpenAiService.defaultRetrofit; -import static dev.ai4j.openai4j.image.ImageModel.DALL_E_SIZE_1024_x_1024; -import static dev.ai4j.openai4j.image.ImageModel.DALL_E_SIZE_512_x_512; -import static dev.langchain4j.model.openai.OpenAiModelName.DALL_E_2; - -@Slf4j -@Service -public class OpenAiHelper { - - @Value("${openai.proxy.enable:false}") - private boolean proxyEnable; - - @Value("${openai.proxy.host:0}") - private String proxyHost; - - @Value("${openai.proxy.http-port:0}") - private int proxyHttpPort; - - @Resource - private FileService fileService; - - @Resource - private ObjectMapper objectMapper; - - public String getSecretKey() { - String secretKey = SysConfigService.getSecretKey(); - User user = ThreadContext.getCurrentUser(); - if (null != user && StringUtils.isNotBlank(user.getSecretKey())) { - secretKey = user.getSecretKey(); - } - return secretKey; - } - - public OpenAiService getOpenAiService() { - String secretKey = getSecretKey(); - if (proxyEnable) { - Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort)); - OkHttpClient client = defaultClient(secretKey, Duration.of(60, ChronoUnit.SECONDS)) - .newBuilder() - .proxy(proxy) - .build(); - Retrofit retrofit = defaultRetrofit(client, objectMapper); - OpenAiApi api = retrofit.create(OpenAiApi.class); - return new OpenAiService(api); - } - return new OpenAiService(secretKey, Duration.of(60, ChronoUnit.SECONDS)); - } - - public IChatAssistant getChatAssistant(ChatMemory chatMemory) { - String secretKey = getSecretKey(); - OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder builder = OpenAiStreamingChatModel.builder(); - if (proxyEnable) { - Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort)); - builder.proxy(proxy); - } - builder.apiKey(secretKey).timeout(Duration.of(60, ChronoUnit.SECONDS)); - AiServices serviceBuilder = AiServices.builder(IChatAssistant.class) - .streamingChatLanguageModel(builder.build()); - if (null != chatMemory) { - serviceBuilder.chatMemory(chatMemory); - } - return serviceBuilder.build(); - } - - public ImageModel getImageModel(User user, String size) { - String secretKey = getSecretKey(); - if (proxyEnable) { - Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort)); - return OpenAiImageModel.builder() - .modelName(DALL_E_2) - .apiKey(secretKey) - .user(user.getUuid()) - .responseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL) - .size(StringUtils.defaultString(size, DALL_E_SIZE_512_x_512)) - .logRequests(true) - .logResponses(true) - .withPersisting(false) - .maxRetries(2) - .proxy(proxy) - .build(); - } - return OpenAiImageModel.builder() - .modelName(DALL_E_2) - .apiKey(secretKey) - .user(user.getUuid()) - .responseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL) - .size(StringUtils.defaultString(size, DALL_E_SIZE_512_x_512)) - .logRequests(true) - .logResponses(true) - .withPersisting(false) - .maxRetries(2) - .build(); - } - - /** - * Send http request to llm server - */ - public void sseAsk(SseAskParams params, TriConsumer consumer) { - IChatAssistant chatAssistant = getChatAssistant(params.getChatMemory()); - TokenStream tokenStream; - if (StringUtils.isNotBlank(params.getSystemMessage())) { - tokenStream = chatAssistant.chat(params.getSystemMessage(), params.getUserMessage()); - } else { - tokenStream = chatAssistant.chat(params.getUserMessage()); - } - tokenStream.onNext((content) -> { - log.info("get content:{}", content); - //加空格配合前端的fetchEventSource进行解析,见https://github.com/Azure/fetch-event-source/blob/45ac3cfffd30b05b79fbf95c21e67d4ef59aa56a/src/parse.ts#L129-L133 - try { - params.getSseEmitter().send(" " + content); - } catch (IOException e) { - log.error("stream onNext error", e); - } - }) - .onComplete((response) -> { - log.info("返回数据结束了:{}", response); - String questionUuid = StringUtils.isNotBlank(params.getRegenerateQuestionUuid()) ? params.getRegenerateQuestionUuid() : UUID.randomUUID().toString().replace("-", ""); - QuestionMeta questionMeta = new QuestionMeta(response.tokenUsage().inputTokenCount(), questionUuid); - AnswerMeta answerMeta = new AnswerMeta(response.tokenUsage().outputTokenCount(), UUID.randomUUID().toString().replace("-", "")); - ChatMeta chatMeta = new ChatMeta(questionMeta, answerMeta); - String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", ""); - log.info("meta:" + meta); - try { - params.getSseEmitter().send(" [META]" + meta); - } catch (IOException e) { - log.error("stream onComplete error", e); - throw new RuntimeException(e); - } - // close eventSourceEmitter after tokens was calculated - params.getSseEmitter().complete(); - consumer.accept(response.content().text(), questionMeta, answerMeta); - }) - .onError((error) -> { - log.error("stream error", error); - try { - params.getSseEmitter().send(SseEmitter.event().name("error").data(error.getMessage())); - } catch (IOException e) { - log.error("sse error", e); - } - params.getSseEmitter().complete(); - }) - .start(); - } - - public List createImage(User user, AiImage aiImage) { - if (aiImage.getGenerateNumber() < 1 || aiImage.getGenerateNumber() > 10) { - throw new BaseException(ErrorEnum.A_IMAGE_NUMBER_ERROR); - } - if (!OPENAI_CREATE_IMAGE_SIZES.contains(aiImage.getGenerateSize())) { - throw new BaseException(ErrorEnum.A_IMAGE_SIZE_ERROR); - } - ImageModel imageModel = getImageModel(user, aiImage.getGenerateSize()); - try { - Response> response = imageModel.generate(aiImage.getPrompt(), aiImage.getGenerateNumber()); - log.info("createImage response:{}", response); - return response.content().stream().map(item -> item.url().toString()).collect(Collectors.toList()); - } catch (Exception e) { - log.error("create image error", e); - } - return Collections.emptyList(); - } - - public List editImage(User user, AiImage aiImage) { - File originalFile = new File(fileService.getImagePath(aiImage.getOriginalImage())); - File maskFile = null; - if (StringUtils.isNotBlank(aiImage.getMaskImage())) { - maskFile = new File(fileService.getImagePath(aiImage.getMaskImage())); - } - //如果不是RGBA类型的图片,先转成RGBA - File rgbaOriginalImage = ImageUtil.rgbConvertToRgba(originalFile, fileService.getTmpImagesPath(aiImage.getOriginalImage())); - OpenAiService service = getOpenAiService(); - CreateImageEditRequest request = new CreateImageEditRequest(); - request.setPrompt(aiImage.getPrompt()); - request.setN(aiImage.getGenerateNumber()); - request.setSize(aiImage.getGenerateSize()); - request.setResponseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL); - request.setUser(user.getUuid()); - try { - ImageResult imageResult = service.createImageEdit(request, rgbaOriginalImage, maskFile); - log.info("editImage response:{}", imageResult); - return imageResult.getData().stream().map(item -> item.getUrl()).collect(Collectors.toList()); - } catch (Exception e) { - log.error("edit image error", e); - } - return Collections.emptyList(); - } - - public List createImageVariation(User user, AiImage aiImage) { - File imagePath = new File(fileService.getImagePath(aiImage.getOriginalImage())); - OpenAiService service = getOpenAiService(); - CreateImageVariationRequest request = new CreateImageVariationRequest(); - request.setN(aiImage.getGenerateNumber()); - request.setSize(aiImage.getGenerateSize()); - request.setResponseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL); - request.setUser(user.getUuid()); - try { - ImageResult imageResult = service.createImageVariation(request, imagePath); - log.info("createImageVariation response:{}", imageResult); - return imageResult.getData().stream().map(item -> item.getUrl()).collect(Collectors.toList()); - } catch (Exception e) { - log.error("image variation error", e); - } - return Collections.emptyList(); - } - -} diff --git a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractImageModelService.java b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractImageModelService.java new file mode 100644 index 0000000..896ef84 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractImageModelService.java @@ -0,0 +1,86 @@ +package com.moyz.adi.common.interfaces; + +import com.moyz.adi.common.entity.AiImage; +import com.moyz.adi.common.entity.User; +import com.moyz.adi.common.enums.ErrorEnum; +import com.moyz.adi.common.exception.BaseException; +import com.moyz.adi.common.util.JsonUtil; +import com.moyz.adi.common.util.LocalCache; +import dev.langchain4j.data.image.Image; +import dev.langchain4j.model.image.ImageModel; +import dev.langchain4j.model.output.Response; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; + +import java.net.Proxy; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static com.moyz.adi.common.cosntant.AdiConstant.OPENAI_CREATE_IMAGE_SIZES; + +@Slf4j +public abstract class AbstractImageModelService { + + protected Proxy proxy; + + protected String modelName; + + protected T setting; + + @Value("${adi.proxy.enable:false}") + protected boolean proxyEnable; + + @Value("${adi.proxy.host:0}") + protected String proxyHost; + + @Value("${adi.proxy.http-port:0}") + protected int proxyHttpPort; + + protected ImageModel imageModel; + + public AbstractImageModelService(String modelName, String settingName, Class clazz, Proxy proxy){ + this.modelName = modelName; + this.proxy = proxy; + String st = LocalCache.CONFIGS.get(settingName); + setting = JsonUtil.fromJson(st, clazz); + } + + public ImageModel getImageModel(User user, String size) { + if (null != imageModel) { + return imageModel; + } + imageModel = buildImageModel(user, size); + return imageModel; + } + + /** + * 检测该service是否可用(不可用的情况通过是没有配置key) + * @return + */ + public abstract boolean isEnabled(); + + protected abstract ImageModel buildImageModel(User user, String size); + + public List createImage(User user, AiImage aiImage) { + if (aiImage.getGenerateNumber() < 1 || aiImage.getGenerateNumber() > 10) { + throw new BaseException(ErrorEnum.A_IMAGE_NUMBER_ERROR); + } + if (!OPENAI_CREATE_IMAGE_SIZES.contains(aiImage.getGenerateSize())) { + throw new BaseException(ErrorEnum.A_IMAGE_SIZE_ERROR); + } + ImageModel imageModel = getImageModel(user, aiImage.getGenerateSize()); + try { + Response> response = imageModel.generate(aiImage.getPrompt(), aiImage.getGenerateNumber()); + log.info("createImage response:{}", response); + return response.content().stream().map(item -> item.url().toString()).collect(Collectors.toList()); + } catch (Exception e) { + log.error("create image error", e); + } + return Collections.emptyList(); + } + + public abstract List editImage(User user, AiImage aiImage); + + public abstract List createImageVariation(User user, AiImage aiImage); +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java new file mode 100644 index 0000000..0173d5d --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java @@ -0,0 +1,128 @@ +package com.moyz.adi.common.interfaces; + +import com.moyz.adi.common.util.JsonUtil; +import com.moyz.adi.common.util.LocalCache; +import com.moyz.adi.common.vo.AnswerMeta; +import com.moyz.adi.common.vo.ChatMeta; +import com.moyz.adi.common.vo.QuestionMeta; +import com.moyz.adi.common.vo.SseAskParams; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.service.AiServices; +import dev.langchain4j.service.TokenStream; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.io.IOException; +import java.net.Proxy; +import java.util.UUID; + +@Slf4j +public abstract class AbstractLLMService { + + protected Proxy proxy; + + protected String modelName; + + protected T setting; + + protected StreamingChatLanguageModel streamingChatLanguageModel; + protected ChatLanguageModel chatLanguageModel; + + public AbstractLLMService(String modelName, String settingName, Class clazz, Proxy proxy){ + this.modelName = modelName; + this.proxy = proxy; + String st = LocalCache.CONFIGS.get(settingName); + setting = JsonUtil.fromJson(st, clazz); + } + + /** + * 检测该service是否可用(不可用的情况通常是没有配置key) + * + * @return + */ + public abstract boolean isEnabled(); + + public ChatLanguageModel getChatLLM() { + if (null != chatLanguageModel) { + return chatLanguageModel; + } + chatLanguageModel = buildChatLLM(); + return chatLanguageModel; + } + + public StreamingChatLanguageModel getStreamingChatLLM() { + if (null != streamingChatLanguageModel) { + return streamingChatLanguageModel; + } + streamingChatLanguageModel = buildStreamingChatLLM(); + return streamingChatLanguageModel; + } + + protected abstract ChatLanguageModel buildChatLLM(); + + protected abstract StreamingChatLanguageModel buildStreamingChatLLM(); + + public String chat(ChatMessage chatMessage) { + return getChatLLM().generate(chatMessage).content().text(); + } + + public void sseChat(SseAskParams params, TriConsumer consumer) { + + //create chat assistant + AiServices serviceBuilder = AiServices.builder(IChatAssistant.class) + .streamingChatLanguageModel(getStreamingChatLLM()); + if (null != params.getChatMemory()) { + serviceBuilder.chatMemory(params.getChatMemory()); + } + IChatAssistant chatAssistant = serviceBuilder.build(); + + TokenStream tokenStream; + if (StringUtils.isNotBlank(params.getSystemMessage())) { + tokenStream = chatAssistant.chat(params.getSystemMessage(), params.getUserMessage()); + } else { + tokenStream = chatAssistant.chat(params.getUserMessage()); + } + tokenStream.onNext((content) -> { + log.info("get content:{}", content); + //加空格配合前端的fetchEventSource进行解析,见https://github.com/Azure/fetch-event-source/blob/45ac3cfffd30b05b79fbf95c21e67d4ef59aa56a/src/parse.ts#L129-L133 + try { + params.getSseEmitter().send(" " + content); + } catch (IOException e) { + log.error("stream onNext error", e); + } + }) + .onComplete((response) -> { + log.info("返回数据结束了:{}", response); + String questionUuid = StringUtils.isNotBlank(params.getRegenerateQuestionUuid()) ? params.getRegenerateQuestionUuid() : UUID.randomUUID().toString().replace("-", ""); + QuestionMeta questionMeta = new QuestionMeta(response.tokenUsage().inputTokenCount(), questionUuid); + AnswerMeta answerMeta = new AnswerMeta(response.tokenUsage().outputTokenCount(), UUID.randomUUID().toString().replace("-", "")); + ChatMeta chatMeta = new ChatMeta(questionMeta, answerMeta); + String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", ""); + log.info("meta:" + meta); + try { + params.getSseEmitter().send(" [META]" + meta); + } catch (IOException e) { + log.error("stream onComplete error", e); + throw new RuntimeException(e); + } + // close eventSourceEmitter after tokens was calculated + params.getSseEmitter().complete(); + consumer.accept(response.content().text(), questionMeta, answerMeta); + }) + .onError((error) -> { + log.error("stream error", error); + try { + params.getSseEmitter().send(SseEmitter.event().name("error").data(error.getMessage())); + } catch (IOException e) { + log.error("sse error", e); + } + params.getSseEmitter().complete(); + }) + .start(); + } + +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/interfaces/TriConsumer.java b/adi-common/src/main/java/com/moyz/adi/common/interfaces/TriConsumer.java new file mode 100644 index 0000000..9d3cb3c --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/interfaces/TriConsumer.java @@ -0,0 +1,6 @@ +package com.moyz.adi.common.interfaces; + +@FunctionalInterface +public interface TriConsumer { + void accept(T t, U u, V v); +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/model/AnswerMeta.java b/adi-common/src/main/java/com/moyz/adi/common/model/AnswerMeta.java deleted file mode 100644 index 6b6f464..0000000 --- a/adi-common/src/main/java/com/moyz/adi/common/model/AnswerMeta.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.moyz.adi.common.model; - -import lombok.AllArgsConstructor; -import lombok.Data; - -@Data -@AllArgsConstructor -public class AnswerMeta { - private Integer tokens; - private String uuid; -} diff --git a/adi-common/src/main/java/com/moyz/adi/common/model/ChatMeta.java b/adi-common/src/main/java/com/moyz/adi/common/model/ChatMeta.java deleted file mode 100644 index 63acf80..0000000 --- a/adi-common/src/main/java/com/moyz/adi/common/model/ChatMeta.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.moyz.adi.common.model; - -import lombok.AllArgsConstructor; -import lombok.Data; - -@Data -@AllArgsConstructor -public class ChatMeta { - private QuestionMeta question; - private AnswerMeta answer; -} diff --git a/adi-common/src/main/java/com/moyz/adi/common/model/CostStat.java b/adi-common/src/main/java/com/moyz/adi/common/model/CostStat.java deleted file mode 100644 index 78dfb5b..0000000 --- a/adi-common/src/main/java/com/moyz/adi/common/model/CostStat.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.moyz.adi.common.model; - -import lombok.Data; - -@Data -public class CostStat { - private int day; - private int textRequestTimesByDay; - private int textTokenCostByDay; - private int imageGeneratedNumberByDay; - private int textTokenCostByMonth; - private int textRequestTimesByMonth; - private int imageGeneratedNumberByMonth; -} diff --git a/adi-common/src/main/java/com/moyz/adi/common/model/QuestionMeta.java b/adi-common/src/main/java/com/moyz/adi/common/model/QuestionMeta.java deleted file mode 100644 index 199c35d..0000000 --- a/adi-common/src/main/java/com/moyz/adi/common/model/QuestionMeta.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.moyz.adi.common.model; - -import lombok.AllArgsConstructor; -import lombok.Data; - -@Data -@AllArgsConstructor -public class QuestionMeta { - private Integer tokens; - private String uuid; -} diff --git a/adi-common/src/main/java/com/moyz/adi/common/model/RequestRateLimit.java b/adi-common/src/main/java/com/moyz/adi/common/model/RequestRateLimit.java deleted file mode 100644 index 6984e96..0000000 --- a/adi-common/src/main/java/com/moyz/adi/common/model/RequestRateLimit.java +++ /dev/null @@ -1,13 +0,0 @@ -package com.moyz.adi.common.model; - -import lombok.Data; - -@Data -public class RequestRateLimit { - private int times; - private int minutes; - private int type; - - public static final int TYPE_TEXT = 1; - public static final int TYPE_IMAGE = 2; -} diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/AiImageService.java b/adi-common/src/main/java/com/moyz/adi/common/service/AiImageService.java index 1a20eaa..1eb2f2c 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/AiImageService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/AiImageService.java @@ -10,7 +10,7 @@ import com.moyz.adi.common.entity.User; import com.moyz.adi.common.entity.UserDayCost; import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.exception.BaseException; -import com.moyz.adi.common.helper.OpenAiHelper; +import com.moyz.adi.common.helper.ImageModelContext; import com.moyz.adi.common.helper.QuotaHelper; import com.moyz.adi.common.helper.RateLimitHelper; import com.moyz.adi.common.mapper.AiImageMapper; @@ -46,10 +46,6 @@ public class AiImageService extends ServiceImpl { @Resource @Lazy private AiImageService _this; - - @Resource - private OpenAiHelper openAiHelper; - @Resource private QuotaHelper quotaHelper; @@ -173,13 +169,14 @@ public class AiImageService extends ServiceImpl { String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId()); rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.IMAGE_RATE_LIMIT_CONFIG); + ImageModelContext modelContext = new ImageModelContext(); List images = new ArrayList<>(); if (aiImage.getInteractingMethod() == INTERACTING_METHOD_GENERATE_IMAGE) { - images = openAiHelper.createImage(user, aiImage); + images = modelContext.getModelService().createImage(user, aiImage); } else if (aiImage.getInteractingMethod() == INTERACTING_METHOD_EDIT_IMAGE) { - images = openAiHelper.editImage(user, aiImage); + images = modelContext.getModelService().editImage(user, aiImage); } else if (aiImage.getInteractingMethod() == INTERACTING_METHOD_VARIATION) { - images = openAiHelper.createImageVariation(user, aiImage); + images = modelContext.getModelService().createImageVariation(user, aiImage); } List imageUuids = new ArrayList(); images.forEach(imageUrl -> { diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/AiModelService.java b/adi-common/src/main/java/com/moyz/adi/common/service/AiModelService.java deleted file mode 100644 index 4ff35cf..0000000 --- a/adi-common/src/main/java/com/moyz/adi/common/service/AiModelService.java +++ /dev/null @@ -1,44 +0,0 @@ -package com.moyz.adi.common.service; - -import com.moyz.adi.common.enums.AiModelStatus; -import com.moyz.adi.common.entity.AiModel; -import com.moyz.adi.common.helper.OpenAiHelper; -import com.moyz.adi.common.mapper.AiModelMapper; -import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; -import jakarta.annotation.PostConstruct; -import jakarta.annotation.Resource; -import org.springframework.stereotype.Service; - -import java.util.ArrayList; -import java.util.List; - -@Service -public class AiModelService extends ServiceImpl { - - public final static List AI_MODELS = new ArrayList<>(); - - @Resource - private OpenAiHelper openAiHelper; - - @PostConstruct - public void init() { - List aiModels = this.lambdaQuery().eq(AiModel::getModelStatus, AiModelStatus.ACTIVE).list(); - AI_MODELS.addAll(aiModels); - - //get models from openai -// List openaiModels = openAiHelper.getModels(); -// for (Model model : openaiModels) { -// AiModel aiModel = this.lambdaQuery().eq(AiModel::getName, model.getId()).one(); -// if (null == aiModel) { -// aiModel = new AiModel(); -// aiModel.setName(model.getId()); -// aiModel.setModelStatus(AiModelStatus.INACTIVE); -// baseMapper.insert(aiModel); -// } -// } - //refresh models cache -// aiModels = this.lambdaQuery().eq(AiModel::getModelStatus, AiModelStatus.ACTIVE).list(); -// AI_MODELS.clear(); -// AI_MODELS.addAll(aiModels); - } -} diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java b/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java index 8f203cd..385b149 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java @@ -13,7 +13,7 @@ import com.moyz.adi.common.entity.UserDayCost; import com.moyz.adi.common.enums.ChatMessageRoleEnum; import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.exception.BaseException; -import com.moyz.adi.common.helper.OpenAiHelper; +import com.moyz.adi.common.helper.LLMContext; import com.moyz.adi.common.helper.QuotaHelper; import com.moyz.adi.common.helper.RateLimitHelper; import com.moyz.adi.common.mapper.ConversationMessageMapper; @@ -59,9 +59,6 @@ public class ConversationMessageService extends ServiceImpl { + new LLMContext(askReq.getModelName()).getLLMService().sseChat(sseAskParams, (response, questionMeta, answerMeta) -> { try { - _this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta); + _this.saveAfterAiResponse(user, askReq, (String) response, (QuestionMeta) questionMeta, (AnswerMeta) answerMeta); } catch (Exception e) { log.error("error:", e); } finally { diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/DashScopeLLMService.java b/adi-common/src/main/java/com/moyz/adi/common/service/DashScopeLLMService.java new file mode 100644 index 0000000..23c2c40 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/service/DashScopeLLMService.java @@ -0,0 +1,55 @@ +package com.moyz.adi.common.service; + +import com.moyz.adi.common.cosntant.AdiConstant; +import com.moyz.adi.common.exception.BaseException; +import com.moyz.adi.common.interfaces.AbstractLLMService; +import com.moyz.adi.common.vo.DashScopeSetting; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.dashscope.QwenChatModel; +import dev.langchain4j.model.dashscope.QwenStreamingChatModel; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; + +import java.net.Proxy; + +import static com.moyz.adi.common.enums.ErrorEnum.B_LLM_SECRET_KEY_NOT_SET; + +/** + * 灵积模型服务(DashScope LLM service) + */ +@Slf4j +public class DashScopeLLMService extends AbstractLLMService { + + public DashScopeLLMService(String modelName, Proxy proxy) { + super(modelName, AdiConstant.SysConfigKey.DASHSCOPE_SETTING, DashScopeSetting.class, proxy); + } + + @Override + public boolean isEnabled() { + return StringUtils.isNotBlank(setting.getApiKey()); + } + + @Override + protected StreamingChatLanguageModel buildStreamingChatLLM() { + if (StringUtils.isBlank(setting.getApiKey())) { + throw new BaseException(B_LLM_SECRET_KEY_NOT_SET); + } + return QwenStreamingChatModel.builder() + .apiKey(setting.getApiKey()) + .modelName(modelName) + .build(); + } + + @Override + protected ChatLanguageModel buildChatLLM() { + if (StringUtils.isBlank(setting.getApiKey())) { + throw new BaseException(B_LLM_SECRET_KEY_NOT_SET); + } + return QwenChatModel.builder() + .apiKey(setting.getApiKey()) + .modelName(modelName) + .build(); + } + +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/Initializer.java b/adi-common/src/main/java/com/moyz/adi/common/service/Initializer.java index de97cc2..2dce9cc 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/Initializer.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/Initializer.java @@ -1,22 +1,48 @@ package com.moyz.adi.common.service; -import com.moyz.adi.common.helper.EmbeddingHelper; +import com.moyz.adi.common.helper.ImageModelContext; +import com.moyz.adi.common.helper.LLMContext; +import dev.langchain4j.model.dashscope.QwenModelName; +import dev.langchain4j.model.openai.OpenAiModelName; import jakarta.annotation.PostConstruct; import jakarta.annotation.Resource; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; +import java.net.InetSocketAddress; +import java.net.Proxy; + @Service public class Initializer { + @Value("${adi.proxy.enable:false}") + protected boolean proxyEnable; + + @Value("${adi.proxy.host:0}") + protected String proxyHost; + + @Value("${adi.proxy.http-port:0}") + protected int proxyHttpPort; + @Resource private SysConfigService sysConfigService; @Resource - private EmbeddingHelper embeddingHelper; + private RAGService ragService; @PostConstruct - public void init(){ + public void init() { sysConfigService.reload(); - embeddingHelper.init(); + + Proxy proxy = null; + if (proxyEnable) { + proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort)); + } + LLMContext.addLLMService(OpenAiModelName.GPT_3_5_TURBO, new OpenAiLLMService(OpenAiModelName.GPT_3_5_TURBO, proxy)); + LLMContext.addLLMService(QwenModelName.QWEN_MAX, new DashScopeLLMService(QwenModelName.QWEN_MAX, proxy)); + ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2, proxy)); + + + ragService.init(); } } diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseEmbeddingService.java b/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseEmbeddingService.java index ea8cba4..cac8631 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseEmbeddingService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseEmbeddingService.java @@ -23,7 +23,13 @@ public class KnowledgeBaseEmbeddingService extends ServiceImpl { @Resource - private EmbeddingHelper embeddingHelper; + private RAGService ragService; @Resource private KnowledgeBaseEmbeddingService knowledgeBaseEmbeddingService; @@ -103,12 +102,20 @@ public class KnowledgeBaseItemService extends ServiceImpl { + + @Resource + private FileService fileService; + + @Resource + private ObjectMapper objectMapper; + + public OpenAiImageModelService(String modelName, Proxy proxy) { + super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class, proxy); + } + + @Override + public boolean isEnabled() { + return StringUtils.isNotBlank(setting.getSecretKey()); + } + + @Override + public ImageModel buildImageModel(User user, String size) { + if (StringUtils.isBlank(setting.getSecretKey())) { + throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET); + } + OpenAiImageModel.OpenAiImageModelBuilder builder = OpenAiImageModel.builder() + .modelName(modelName) + .apiKey(setting.getSecretKey()) + .user(user.getUuid()) + .responseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL) + .size(StringUtils.defaultString(size, DALL_E_SIZE_512_x_512)) + .logRequests(true) + .logResponses(true) + .withPersisting(false) + .maxRetries(2); + if (null != proxy) { + builder.proxy(proxy); + } + return builder.build(); + } + + @Override + public List editImage(User user, AiImage aiImage) { + File originalFile = new File(fileService.getImagePath(aiImage.getOriginalImage())); + File maskFile = null; + if (StringUtils.isNotBlank(aiImage.getMaskImage())) { + maskFile = new File(fileService.getImagePath(aiImage.getMaskImage())); + } + //如果不是RGBA类型的图片,先转成RGBA + File rgbaOriginalImage = ImageUtil.rgbConvertToRgba(originalFile, fileService.getTmpImagesPath(aiImage.getOriginalImage())); + OpenAiService service = getOpenAiService(); + CreateImageEditRequest request = new CreateImageEditRequest(); + request.setPrompt(aiImage.getPrompt()); + request.setN(aiImage.getGenerateNumber()); + request.setSize(aiImage.getGenerateSize()); + request.setResponseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL); + request.setUser(user.getUuid()); + try { + ImageResult imageResult = service.createImageEdit(request, rgbaOriginalImage, maskFile); + log.info("editImage response:{}", imageResult); + return imageResult.getData().stream().map(item -> item.getUrl()).collect(Collectors.toList()); + } catch (Exception e) { + log.error("edit image error", e); + } + return Collections.emptyList(); + } + + @Override + public List createImageVariation(User user, AiImage aiImage) { + File imagePath = new File(fileService.getImagePath(aiImage.getOriginalImage())); + OpenAiService service = getOpenAiService(); + CreateImageVariationRequest request = new CreateImageVariationRequest(); + request.setN(aiImage.getGenerateNumber()); + request.setSize(aiImage.getGenerateSize()); + request.setResponseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL); + request.setUser(user.getUuid()); + try { + ImageResult imageResult = service.createImageVariation(request, imagePath); + log.info("createImageVariation response:{}", imageResult); + return imageResult.getData().stream().map(item -> item.getUrl()).collect(Collectors.toList()); + } catch (Exception e) { + log.error("image variation error", e); + } + return Collections.emptyList(); + } + + public OpenAiService getOpenAiService() { + String secretKey = setting.getSecretKey(); + if (proxyEnable) { + Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort)); + OkHttpClient client = defaultClient(secretKey, Duration.of(60, ChronoUnit.SECONDS)) + .newBuilder() + .proxy(proxy) + .build(); + Retrofit retrofit = defaultRetrofit(client, objectMapper); + OpenAiApi api = retrofit.create(OpenAiApi.class); + return new OpenAiService(api); + } + return new OpenAiService(secretKey, Duration.of(60, ChronoUnit.SECONDS)); + } +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/OpenAiLLMService.java b/adi-common/src/main/java/com/moyz/adi/common/service/OpenAiLLMService.java new file mode 100644 index 0000000..1a5da32 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/service/OpenAiLLMService.java @@ -0,0 +1,65 @@ +package com.moyz.adi.common.service; + +import com.moyz.adi.common.cosntant.AdiConstant; +import com.moyz.adi.common.enums.ErrorEnum; +import com.moyz.adi.common.exception.BaseException; +import com.moyz.adi.common.interfaces.AbstractLLMService; +import com.moyz.adi.common.vo.OpenAiSetting; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.model.openai.OpenAiStreamingChatModel; +import lombok.experimental.Accessors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; + +import java.net.Proxy; +import java.time.Duration; +import java.time.temporal.ChronoUnit; + +/** + * OpenAi LLM service + */ +@Slf4j +@Accessors(chain = true) +public class OpenAiLLMService extends AbstractLLMService { + + public OpenAiLLMService(String modelName, Proxy proxy) { + super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class, proxy); + } + + @Override + public boolean isEnabled() { + return StringUtils.isNotBlank(setting.getSecretKey()); + } + + @Override + protected StreamingChatLanguageModel buildStreamingChatLLM() { + if (StringUtils.isBlank(setting.getSecretKey())) { + throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET); + } + OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder builder = OpenAiStreamingChatModel + .builder() + .modelName(modelName) + .apiKey(setting.getSecretKey()) + .timeout(Duration.of(60, ChronoUnit.SECONDS)); + if (null != proxy) { + builder.proxy(proxy); + } + return builder.build(); + } + + @Override + protected ChatLanguageModel buildChatLLM() { + if (StringUtils.isBlank(setting.getSecretKey())) { + throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET); + } + OpenAiChatModel.OpenAiChatModelBuilder builder = OpenAiChatModel.builder().apiKey(setting.getSecretKey()); + if (null != proxy) { + builder.proxy(proxy); + } + return builder.build(); + } + + +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/helper/EmbeddingHelper.java b/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java similarity index 76% rename from adi-common/src/main/java/com/moyz/adi/common/helper/EmbeddingHelper.java rename to adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java index df8e273..b739480 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/helper/EmbeddingHelper.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java @@ -1,30 +1,25 @@ -package com.moyz.adi.common.helper; +package com.moyz.adi.common.service; +import com.moyz.adi.common.helper.LLMContext; import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore; +import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.DocumentSplitter; import dev.langchain4j.data.document.splitter.DocumentSplitters; import dev.langchain4j.data.embedding.Embedding; -import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; -import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.model.openai.OpenAiTokenizer; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreIngestor; -import jakarta.annotation.PostConstruct; -import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; -import java.net.InetSocketAddress; -import java.net.Proxy; import java.util.List; import java.util.Map; import java.util.regex.Matcher; @@ -35,8 +30,7 @@ import static java.util.stream.Collectors.joining; @Slf4j @Service -public class EmbeddingHelper { - +public class RAGService { @Value("${spring.datasource.url}") private String dataBaseUrl; @@ -45,32 +39,16 @@ public class EmbeddingHelper { @Value("${spring.datasource.password}") private String dataBasePassword; - - @Value("${openai.proxy.enable:false}") - private boolean proxyEnable; - - @Value("${openai.proxy.host:0}") - private String proxyHost; - - @Value("${openai.proxy.http-port:0}") - private int proxyHttpPort; - private static final PromptTemplate promptTemplate = PromptTemplate.from("尽可能准确地回答下面的问题: {{question}}\n\n根据以下知识库的内容:\n{{information}}"); - @Resource - private OpenAiHelper openAiHelper; - private EmbeddingModel embeddingModel; private EmbeddingStore embeddingStore; - private ChatLanguageModel chatLanguageModel; - public void init() { log.info("initEmbeddingModel"); embeddingModel = new AllMiniLmL6V2EmbeddingModel(); embeddingStore = initEmbeddingStore(); - chatLanguageModel = initChatLanguageModel(); } private EmbeddingStore initEmbeddingStore() { @@ -107,16 +85,7 @@ public class EmbeddingHelper { return embeddingStore; } - private ChatLanguageModel initChatLanguageModel() { - OpenAiChatModel.OpenAiChatModelBuilder builder = OpenAiChatModel.builder().apiKey(openAiHelper.getSecretKey()); - if (proxyEnable) { - Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort)); - builder.proxy(proxy); - } - return builder.build(); - } - - public EmbeddingStoreIngestor getEmbeddingStoreIngestor() { + private EmbeddingStoreIngestor getEmbeddingStoreIngestor() { DocumentSplitter documentSplitter = DocumentSplitters.recursive(1000, 0, new OpenAiTokenizer(GPT_3_5_TURBO)); EmbeddingStoreIngestor embeddingStoreIngestor = EmbeddingStoreIngestor.builder() .documentSplitter(documentSplitter) @@ -126,7 +95,24 @@ public class EmbeddingHelper { return embeddingStoreIngestor; } - public String findAnswer(String kbUuid, String question) { + /** + * 对文档切块并向量化 + * + * @param document 知识库文档 + */ + public void ingest(Document document) { + getEmbeddingStoreIngestor().ingest(document); + } + + /** + * 召回并搜索 + * + * @param kbUuid 知识库uuid + * @param question 用户的问题 + * @param modelName LLM model name + * @return + */ + public String findAnswer(String kbUuid, String question, String modelName) { // Embed the question Embedding questionEmbedding = embeddingModel.embed(question).content(); @@ -147,10 +133,6 @@ public class EmbeddingHelper { } Prompt prompt = promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information))); - AiMessage aiMessage = chatLanguageModel.generate(prompt.toUserMessage()).content(); - - // See an answer from the model - return aiMessage.text(); + return new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage()); } - } diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/SysConfigService.java b/adi-common/src/main/java/com/moyz/adi/common/service/SysConfigService.java index 21f58ca..dc31273 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/SysConfigService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/SysConfigService.java @@ -57,10 +57,6 @@ public class SysConfigService extends ServiceImpl { return Integer.parseInt(maxNum); } - public static String getSecretKey() { - return LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.SECRET_KEY); - } - public static String getByKey(String key) { return LocalCache.CONFIGS.get(key); } diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/UserService.java b/adi-common/src/main/java/com/moyz/adi/common/service/UserService.java index a925635..f16a061 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/UserService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/UserService.java @@ -98,7 +98,7 @@ public class UserService extends ServiceImpl { .eq(User::getEmail, email) .one(); if (null != user && user.getUserStatus() == UserStatusEnum.NORMAL) { - throw new BaseException(A_REGISTER_USER_EXIST); + throw new BaseException(A_USER_EXIST); } if (null != user) { sendActiveEmail(email); @@ -112,7 +112,7 @@ public class UserService extends ServiceImpl { //创建用户 User newOne = new User(); - newOne.setName(StringUtils.substringBetween(email, "@")); + newOne.setName(StringUtils.substringBefore(email, "@")); newOne.setUuid(UUID.randomUUID().toString().replace("-", "")); newOne.setEmail(email); newOne.setPassword(hashed); diff --git a/adi-common/src/main/java/com/moyz/adi/common/vo/DashScopeSetting.java b/adi-common/src/main/java/com/moyz/adi/common/vo/DashScopeSetting.java new file mode 100644 index 0000000..792e2df --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/vo/DashScopeSetting.java @@ -0,0 +1,11 @@ +package com.moyz.adi.common.vo; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; + +@Data +public class DashScopeSetting { + + @JsonProperty("api_key") + private String apiKey; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/vo/ImageModelInfo.java b/adi-common/src/main/java/com/moyz/adi/common/vo/ImageModelInfo.java new file mode 100644 index 0000000..297077b --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/vo/ImageModelInfo.java @@ -0,0 +1,12 @@ +package com.moyz.adi.common.vo; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.moyz.adi.common.interfaces.AbstractImageModelService; +import lombok.Data; + +@Data +public class ImageModelInfo extends ModelInfo { + + @JsonIgnore + private AbstractImageModelService modelService; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/vo/LLMModelInfo.java b/adi-common/src/main/java/com/moyz/adi/common/vo/LLMModelInfo.java new file mode 100644 index 0000000..1226fe3 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/vo/LLMModelInfo.java @@ -0,0 +1,12 @@ +package com.moyz.adi.common.vo; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.moyz.adi.common.interfaces.AbstractLLMService; +import lombok.Data; + +@Data +public class LLMModelInfo extends ModelInfo { + + @JsonIgnore + private AbstractLLMService llmService; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/vo/ModelInfo.java b/adi-common/src/main/java/com/moyz/adi/common/vo/ModelInfo.java new file mode 100644 index 0000000..cce403d --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/vo/ModelInfo.java @@ -0,0 +1,9 @@ +package com.moyz.adi.common.vo; + +import lombok.Data; + +@Data +public class ModelInfo { + private String modelName; + private Boolean enable; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/vo/OpenAiSetting.java b/adi-common/src/main/java/com/moyz/adi/common/vo/OpenAiSetting.java new file mode 100644 index 0000000..a51828c --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/vo/OpenAiSetting.java @@ -0,0 +1,11 @@ +package com.moyz.adi.common.vo; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; + +@Data +public class OpenAiSetting { + + @JsonProperty("secret_key") + private String secretKey; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/vo/SseAskParams.java b/adi-common/src/main/java/com/moyz/adi/common/vo/SseAskParams.java index 2ae33f9..9c79160 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/vo/SseAskParams.java +++ b/adi-common/src/main/java/com/moyz/adi/common/vo/SseAskParams.java @@ -1,7 +1,6 @@ package com.moyz.adi.common.vo; import com.moyz.adi.common.entity.User; -import com.moyz.adi.common.util.TriConsumer; import dev.langchain4j.memory.ChatMemory; import lombok.Data; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; diff --git a/adi-common/src/main/resources/mapper/KnowledgeBaseEmbedding.xml b/adi-common/src/main/resources/mapper/KnowledgeBaseEmbedding.xml index 03ed9f5..fcfc866 100644 --- a/adi-common/src/main/resources/mapper/KnowledgeBaseEmbedding.xml +++ b/adi-common/src/main/resources/mapper/KnowledgeBaseEmbedding.xml @@ -3,10 +3,23 @@ + + delete from adi_knowledge_base_embedding where embedding_id in + + + #{id} + + + + - delete from adi_knowledge_base_embedding where metadata->>'kb_item_uuid' = #{kbItemUuid} + delete + from adi_knowledge_base_embedding + where metadata ->> 'kb_item_uuid' = #{kbItemUuid} diff --git a/docs/create.sql b/docs/create.sql index 2d41254..8169342 100644 --- a/docs/create.sql +++ b/docs/create.sql @@ -201,7 +201,7 @@ CREATE TABLE public.adi_sys_config ( id bigserial primary key, name character varying(100) DEFAULT ''::character varying NOT NULL, - value character varying(100) DEFAULT ''::character varying NOT NULL, + value character varying(1000) DEFAULT ''::character varying NOT NULL, create_time timestamp DEFAULT localtimestamp NOT NULL, update_time timestamp DEFAULT localtimestamp NOT NULL, is_deleted boolean DEFAULT false NOT NULL @@ -382,7 +382,9 @@ CREATE TRIGGER trigger_user_day_cost_update_time EXECUTE PROCEDURE update_modified_column(); INSERT INTO adi_sys_config (name, value) -VALUES ('secret_key', ''); +VALUES ('openai_setting', '{"secret_key":""}'); +INSERT INTO adi_sys_config (name, value) +VALUES ('dashscope_setting', '{"api_key":""}'); INSERT INTO adi_sys_config (name, value) VALUES ('request_text_rate_limit', '{"times":24,"minutes":3}'); INSERT INTO adi_sys_config (name, value) diff --git a/pom.xml b/pom.xml index 2005208..16526df 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ 17 17 UTF-8 - 0.25.0 + 0.27.1 @@ -160,6 +160,11 @@ langchain4j-document-parser-apache-poi ${langchain4j.version} + + dev.langchain4j + langchain4j-dashscope + ${langchain4j.version} + org.springframework.boot spring-boot-starter-test