增加支持通义千问

This commit is contained in:
moyangzhan 2024-02-27 23:55:53 +08:00
parent 0a2f2c1d70
commit ae26772ef5
48 changed files with 816 additions and 472 deletions

View File

@ -10,10 +10,8 @@
* 图片生成(文生图、修图、图生图) * 图片生成(文生图、修图、图生图)
* 提示词 * 提示词
* 额度控制 * 额度控制
* 自定义openai secret_key
* 基于大模型的知识库RAG * 基于大模型的知识库RAG
* 多模型随意切换
![1691585301627](image/README/1691585301627.png "登录注册")
**AI聊天** **AI聊天**
![1691583184761](image/README/1691583184761.png) ![1691583184761](image/README/1691583184761.png)
@ -31,14 +29,18 @@
![kb03](image/README/kb03.png) ![kb03](image/README/kb03.png)
体验网址:[http://www.aideepin.com](http://www.aideepin.com/) ### 体验网址
[http://www.aideepin.com](http://www.aideepin.com/)
接入的模型ChatGPT 3.5DALL-E 2 ### 接入的模型:
* ChatGPT 3.5
该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web) * 通义千问
* DALL-E 2
### 技术 ### 技术
该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web)
后端: 后端:
jdk17 jdk17
@ -61,10 +63,11 @@ vue3+typescript+pnpm
* 创建数据库aideepin * 创建数据库aideepin
* 执行docs/create.sql * 执行docs/create.sql
* 填充openai的secret\_key * 填充openai的secretKey 或者 灵积模型的apiKey
```plaintext ```plaintext
update adi_sys_config set value = 'my_chatgpt_secret_key' where name = 'secret_key' update adi_sys_config set value = '{"secret_key":"my_openai_secret_key"}' where name = 'openai_setting';
update adi_sys_config set value = '{"api_key":"my_dashcope_api_key"}' where name = 'dashscope_setting';
``` ```
* 修改配置文件 * 修改配置文件

View File

@ -25,17 +25,14 @@ logging:
file: file:
path: D:/data/logs path: D:/data/logs
openai: adi:
frontend-url: http://localhost:1002
backend-url: http://localhost:1002/api
proxy: proxy:
enable: true enable: true
host: 127.0.0.1 host: 127.0.0.1
http-port: 1087 http-port: 1087
adi:
frontend-url: http://localhost:1002
backend-url: http://localhost:1002/api
local: local:
files: D:/data/files/ files: D:/data/files/
images: D:/data/images/ images: D:/data/images/

View File

@ -53,6 +53,10 @@ logging:
adi: adi:
frontend-url: http://www.aideepin.com frontend-url: http://www.aideepin.com
backend-url: http://www.aideepin.com/api backend-url: http://www.aideepin.com/api
proxy:
enable: false
host: 127.0.0.1
http-port: 1087
local: local:
files: /data/files/ files: /data/files/

View File

@ -39,6 +39,12 @@ public class KnowledgeBaseItemController {
.one(); .one();
} }
/**
* 知识点向量化
*
* @param uuid 知识点uuid
* @return
*/
@PostMapping("/embedding/{uuid}") @PostMapping("/embedding/{uuid}")
public boolean embedding(@PathVariable String uuid) { public boolean embedding(@PathVariable String uuid) {
return knowledgeBaseItemService.checkAndEmbedding(uuid); return knowledgeBaseItemService.checkAndEmbedding(uuid);

View File

@ -25,7 +25,7 @@ public class KnowledgeBaseQAController {
@PostMapping("/ask/{kbUuid}") @PostMapping("/ask/{kbUuid}")
public KnowledgeBaseQaRecord ask(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) { public KnowledgeBaseQaRecord ask(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) {
return knowledgeBaseService.answerAndRecord(kbUuid, req.getQuestion()); return knowledgeBaseService.answerAndRecord(kbUuid, req.getQuestion(), req.getModelName());
} }
@GetMapping("/record/search") @GetMapping("/record/search")

View File

@ -0,0 +1,29 @@
package com.moyz.adi.chat.controller;
import com.moyz.adi.common.helper.ImageModelContext;
import com.moyz.adi.common.helper.LLMContext;
import com.moyz.adi.common.vo.ImageModelInfo;
import com.moyz.adi.common.vo.LLMModelInfo;
import io.swagger.v3.oas.annotations.Operation;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
import java.util.stream.Collectors;
@RestController
@RequestMapping("/model")
public class ModelController {
@Operation(summary = "支持的大语言模型列表")
@GetMapping(value = "/llms")
public List<LLMModelInfo> llms() {
return LLMContext.NAME_TO_MODEL.values().stream().collect(Collectors.toList());
}
@Operation(summary = "支持的图片模型列表")
@GetMapping(value = "/imageModels")
public List<ImageModelInfo> imageModels() {
return ImageModelContext.NAME_TO_MODEL.values().stream().collect(Collectors.toList());
}
}

View File

@ -65,7 +65,8 @@ public class AdiConstant {
} }
public static class SysConfigKey { public static class SysConfigKey {
public static final String SECRET_KEY = "secret_key"; public static final String OPENAI_SETTING = "openai_setting";
public static final String DASHSCOPE_SETTING = "dashscope_setting";
public static final String REQUEST_TEXT_RATE_LIMIT = "request_text_rate_limit"; public static final String REQUEST_TEXT_RATE_LIMIT = "request_text_rate_limit";
public static final String REQUEST_IMAGE_RATE_LIMIT = "request_image_rate_limit"; public static final String REQUEST_IMAGE_RATE_LIMIT = "request_image_rate_limit";
public static final String CONVERSATION_MAX_NUM = "conversation_max_num"; public static final String CONVERSATION_MAX_NUM = "conversation_max_num";

View File

@ -21,4 +21,6 @@ public class AskReq {
* If not empty, it means will request AI with the exist prompt, param {@code prompt} is ignored * If not empty, it means will request AI with the exist prompt, param {@code prompt} is ignored
*/ */
private String regenerateQuestionUuid; private String regenerateQuestionUuid;
private String modelName;
} }

View File

@ -19,4 +19,6 @@ public class EditImageReq {
@Min(1) @Min(1)
@Max(10) @Max(10)
private int number; private int number;
private String modelName;
} }

View File

@ -14,4 +14,6 @@ public class GenerateImageReq {
@Min(1) @Min(1)
@Max(10) @Max(10)
private int number; private int number;
private String modelName;
} }

View File

@ -10,4 +10,6 @@ public class QAReq {
@NotBlank @NotBlank
private String question; private String question;
private String modelName;
} }

View File

@ -15,4 +15,6 @@ public class VariationImageReq {
@Min(1) @Min(1)
@Max(10) @Max(10)
private int number; private int number;
private String modelName;
} }

View File

@ -1,6 +1,8 @@
package com.moyz.adi.common.entity; package com.moyz.adi.common.entity;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import com.pgvector.PGvector; import com.pgvector.PGvector;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
@ -9,10 +11,10 @@ import lombok.Data;
@Data @Data
@TableName("adi_knowledge_base_embedding") @TableName("adi_knowledge_base_embedding")
@Schema(title = "知识库-嵌入实体", description = "知识库嵌入表") @Schema(title = "知识库-嵌入实体", description = "知识库嵌入表")
public class KnowledgeBaseEmbedding extends BaseEntity { public class KnowledgeBaseEmbedding{
@Schema(title = "embedding uuid") @Schema(title = "embedding_id")
@TableField("embedding") @TableId(value = "embedding_id", type = IdType.AUTO)
private String embeddingId; private String embeddingId;
@Schema(title = "embedding") @Schema(title = "embedding")

View File

@ -14,7 +14,7 @@ public enum ErrorEnum {
A_IMAGE_SIZE_ERROR("A0010", "图片尺寸不对"), A_IMAGE_SIZE_ERROR("A0010", "图片尺寸不对"),
A_FILE_NOT_EXIST("A0011", "文件不存在"), A_FILE_NOT_EXIST("A0011", "文件不存在"),
A_DRAWING("A0012", "作图还未完成"), A_DRAWING("A0012", "作图还未完成"),
A_REGISTER_USER_EXIST("A0013", "账号已经存在,请使用账号密码登录"), A_USER_EXIST("A0013", "账号已经存在,请使用账号密码登录"),
A_FIND_PASSWORD_CODE_ERROR("A0014", "重置码已过期或不存在"), A_FIND_PASSWORD_CODE_ERROR("A0014", "重置码已过期或不存在"),
A_USER_WAIT_CONFIRM("A0015", "用户未激活"), A_USER_WAIT_CONFIRM("A0015", "用户未激活"),
A_USER_NOT_AUTH("A0016", "用户无权限"), A_USER_NOT_AUTH("A0016", "用户无权限"),
@ -29,7 +29,8 @@ public enum ErrorEnum {
B_FIND_IMAGE_404("B0005", "无法找到图片"), B_FIND_IMAGE_404("B0005", "无法找到图片"),
B_DAILY_QUOTA_USED("B0006", "今天额度已经用完"), B_DAILY_QUOTA_USED("B0006", "今天额度已经用完"),
B_MONTHLY_QUOTA_USED("B0007", "当月额度已经用完"), B_MONTHLY_QUOTA_USED("B0007", "当月额度已经用完"),
B_LLM_NOT_SUPPORT("B0008", "LLM不支持该功能"),
B_LLM_SECRET_KEY_NOT_SET("B0009", "LLM的secret key没设置"),
B_MESSAGE_NOT_FOUND("B0008", "消息不存在"); B_MESSAGE_NOT_FOUND("B0008", "消息不存在");
private String code; private String code;

View File

@ -29,7 +29,8 @@ import static org.springframework.http.HttpHeaders.AUTHORIZATION;
public class TokenFilter extends OncePerRequestFilter { public class TokenFilter extends OncePerRequestFilter {
public static final String[] EXCLUDE_API = { public static final String[] EXCLUDE_API = {
"/auth/" "/auth/",
"/model/"
}; };
@Resource @Resource

View File

@ -0,0 +1,49 @@
package com.moyz.adi.common.helper;
import com.moyz.adi.common.interfaces.AbstractImageModelService;
import com.moyz.adi.common.vo.ImageModelInfo;
import lombok.extern.slf4j.Slf4j;
import java.util.HashMap;
import java.util.Map;
import static dev.langchain4j.model.openai.OpenAiModelName.DALL_E_2;
/**
* image model service上下文类策略模式
*/
@Slf4j
public class ImageModelContext {
/**
* AI图片模型
*/
public static final Map<String, ImageModelInfo> NAME_TO_MODEL = new HashMap<>();
private AbstractImageModelService modelService;
public ImageModelContext() {
modelService = NAME_TO_MODEL.get(DALL_E_2).getModelService();
}
public ImageModelContext(String modelName) {
if (null == NAME_TO_MODEL.get(modelName)) {
log.warn("︿︿︿ Can not find {}, use the default model DALL_E_2 ︿︿︿", modelName);
modelService = NAME_TO_MODEL.get(DALL_E_2).getModelService();
} else {
modelService = NAME_TO_MODEL.get(modelName).getModelService();
}
}
public static void addImageModelService(String modelName, AbstractImageModelService modelService) {
ImageModelInfo imageModelInfo = new ImageModelInfo();
imageModelInfo.setModelService(modelService);
imageModelInfo.setModelName(modelName);
imageModelInfo.setEnable(modelService.isEnabled());
NAME_TO_MODEL.put(modelName, imageModelInfo);
}
public AbstractImageModelService getModelService() {
return modelService;
}
}

View File

@ -0,0 +1,44 @@
package com.moyz.adi.common.helper;
import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.vo.LLMModelInfo;
import lombok.extern.slf4j.Slf4j;
import java.util.HashMap;
import java.util.Map;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
/**
* llmService上下文类策略模式
*/
@Slf4j
public class LLMContext {
public static final Map<String, LLMModelInfo> NAME_TO_MODEL = new HashMap<>();
private AbstractLLMService llmService;
public LLMContext() {
llmService = NAME_TO_MODEL.get(GPT_3_5_TURBO).getLlmService();
}
public LLMContext(String modelName) {
if (null == NAME_TO_MODEL.get(modelName)) {
log.warn("︿︿︿ Can not find {}, use the default model GPT_3_5_TURBO ︿︿︿", modelName);
llmService = NAME_TO_MODEL.get(GPT_3_5_TURBO).getLlmService();
} else {
llmService = NAME_TO_MODEL.get(modelName).getLlmService();
}
}
public static void addLLMService(String modelName, AbstractLLMService llmService) {
LLMModelInfo llmModelInfo = new LLMModelInfo();
llmModelInfo.setModelName(modelName);
llmModelInfo.setEnable(llmService.isEnabled());
llmModelInfo.setLlmService(llmService);
NAME_TO_MODEL.put(modelName, llmModelInfo);
}
public AbstractLLMService getLLMService() {
return llmService;
}
}

View File

@ -1,260 +0,0 @@
package com.moyz.adi.common.helper;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.entity.AiImage;
import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.interfaces.IChatAssistant;
import com.moyz.adi.common.service.FileService;
import com.moyz.adi.common.service.SysConfigService;
import com.moyz.adi.common.util.ImageUtil;
import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.TriConsumer;
import com.moyz.adi.common.vo.AnswerMeta;
import com.moyz.adi.common.vo.ChatMeta;
import com.moyz.adi.common.vo.QuestionMeta;
import com.moyz.adi.common.vo.SseAskParams;
import com.theokanning.openai.OpenAiApi;
import com.theokanning.openai.image.CreateImageEditRequest;
import com.theokanning.openai.image.CreateImageVariationRequest;
import com.theokanning.openai.image.ImageResult;
import com.theokanning.openai.service.OpenAiService;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.openai.OpenAiImageModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import retrofit2.Retrofit;
import java.io.File;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import static com.moyz.adi.common.cosntant.AdiConstant.OPENAI_CREATE_IMAGE_RESP_FORMATS_URL;
import static com.moyz.adi.common.cosntant.AdiConstant.OPENAI_CREATE_IMAGE_SIZES;
import static com.theokanning.openai.service.OpenAiService.defaultClient;
import static com.theokanning.openai.service.OpenAiService.defaultRetrofit;
import static dev.ai4j.openai4j.image.ImageModel.DALL_E_SIZE_1024_x_1024;
import static dev.ai4j.openai4j.image.ImageModel.DALL_E_SIZE_512_x_512;
import static dev.langchain4j.model.openai.OpenAiModelName.DALL_E_2;
@Slf4j
@Service
public class OpenAiHelper {
@Value("${openai.proxy.enable:false}")
private boolean proxyEnable;
@Value("${openai.proxy.host:0}")
private String proxyHost;
@Value("${openai.proxy.http-port:0}")
private int proxyHttpPort;
@Resource
private FileService fileService;
@Resource
private ObjectMapper objectMapper;
public String getSecretKey() {
String secretKey = SysConfigService.getSecretKey();
User user = ThreadContext.getCurrentUser();
if (null != user && StringUtils.isNotBlank(user.getSecretKey())) {
secretKey = user.getSecretKey();
}
return secretKey;
}
public OpenAiService getOpenAiService() {
String secretKey = getSecretKey();
if (proxyEnable) {
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort));
OkHttpClient client = defaultClient(secretKey, Duration.of(60, ChronoUnit.SECONDS))
.newBuilder()
.proxy(proxy)
.build();
Retrofit retrofit = defaultRetrofit(client, objectMapper);
OpenAiApi api = retrofit.create(OpenAiApi.class);
return new OpenAiService(api);
}
return new OpenAiService(secretKey, Duration.of(60, ChronoUnit.SECONDS));
}
public IChatAssistant getChatAssistant(ChatMemory chatMemory) {
String secretKey = getSecretKey();
OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder builder = OpenAiStreamingChatModel.builder();
if (proxyEnable) {
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort));
builder.proxy(proxy);
}
builder.apiKey(secretKey).timeout(Duration.of(60, ChronoUnit.SECONDS));
AiServices<IChatAssistant> serviceBuilder = AiServices.builder(IChatAssistant.class)
.streamingChatLanguageModel(builder.build());
if (null != chatMemory) {
serviceBuilder.chatMemory(chatMemory);
}
return serviceBuilder.build();
}
public ImageModel getImageModel(User user, String size) {
String secretKey = getSecretKey();
if (proxyEnable) {
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort));
return OpenAiImageModel.builder()
.modelName(DALL_E_2)
.apiKey(secretKey)
.user(user.getUuid())
.responseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL)
.size(StringUtils.defaultString(size, DALL_E_SIZE_512_x_512))
.logRequests(true)
.logResponses(true)
.withPersisting(false)
.maxRetries(2)
.proxy(proxy)
.build();
}
return OpenAiImageModel.builder()
.modelName(DALL_E_2)
.apiKey(secretKey)
.user(user.getUuid())
.responseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL)
.size(StringUtils.defaultString(size, DALL_E_SIZE_512_x_512))
.logRequests(true)
.logResponses(true)
.withPersisting(false)
.maxRetries(2)
.build();
}
/**
* Send http request to llm server
*/
public void sseAsk(SseAskParams params, TriConsumer<String, QuestionMeta, AnswerMeta> consumer) {
IChatAssistant chatAssistant = getChatAssistant(params.getChatMemory());
TokenStream tokenStream;
if (StringUtils.isNotBlank(params.getSystemMessage())) {
tokenStream = chatAssistant.chat(params.getSystemMessage(), params.getUserMessage());
} else {
tokenStream = chatAssistant.chat(params.getUserMessage());
}
tokenStream.onNext((content) -> {
log.info("get content:{}", content);
//加空格配合前端的fetchEventSource进行解析见https://github.com/Azure/fetch-event-source/blob/45ac3cfffd30b05b79fbf95c21e67d4ef59aa56a/src/parse.ts#L129-L133
try {
params.getSseEmitter().send(" " + content);
} catch (IOException e) {
log.error("stream onNext error", e);
}
})
.onComplete((response) -> {
log.info("返回数据结束了:{}", response);
String questionUuid = StringUtils.isNotBlank(params.getRegenerateQuestionUuid()) ? params.getRegenerateQuestionUuid() : UUID.randomUUID().toString().replace("-", "");
QuestionMeta questionMeta = new QuestionMeta(response.tokenUsage().inputTokenCount(), questionUuid);
AnswerMeta answerMeta = new AnswerMeta(response.tokenUsage().outputTokenCount(), UUID.randomUUID().toString().replace("-", ""));
ChatMeta chatMeta = new ChatMeta(questionMeta, answerMeta);
String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", "");
log.info("meta:" + meta);
try {
params.getSseEmitter().send(" [META]" + meta);
} catch (IOException e) {
log.error("stream onComplete error", e);
throw new RuntimeException(e);
}
// close eventSourceEmitter after tokens was calculated
params.getSseEmitter().complete();
consumer.accept(response.content().text(), questionMeta, answerMeta);
})
.onError((error) -> {
log.error("stream error", error);
try {
params.getSseEmitter().send(SseEmitter.event().name("error").data(error.getMessage()));
} catch (IOException e) {
log.error("sse error", e);
}
params.getSseEmitter().complete();
})
.start();
}
public List<String> createImage(User user, AiImage aiImage) {
if (aiImage.getGenerateNumber() < 1 || aiImage.getGenerateNumber() > 10) {
throw new BaseException(ErrorEnum.A_IMAGE_NUMBER_ERROR);
}
if (!OPENAI_CREATE_IMAGE_SIZES.contains(aiImage.getGenerateSize())) {
throw new BaseException(ErrorEnum.A_IMAGE_SIZE_ERROR);
}
ImageModel imageModel = getImageModel(user, aiImage.getGenerateSize());
try {
Response<List<Image>> response = imageModel.generate(aiImage.getPrompt(), aiImage.getGenerateNumber());
log.info("createImage response:{}", response);
return response.content().stream().map(item -> item.url().toString()).collect(Collectors.toList());
} catch (Exception e) {
log.error("create image error", e);
}
return Collections.emptyList();
}
public List<String> editImage(User user, AiImage aiImage) {
File originalFile = new File(fileService.getImagePath(aiImage.getOriginalImage()));
File maskFile = null;
if (StringUtils.isNotBlank(aiImage.getMaskImage())) {
maskFile = new File(fileService.getImagePath(aiImage.getMaskImage()));
}
//如果不是RGBA类型的图片先转成RGBA
File rgbaOriginalImage = ImageUtil.rgbConvertToRgba(originalFile, fileService.getTmpImagesPath(aiImage.getOriginalImage()));
OpenAiService service = getOpenAiService();
CreateImageEditRequest request = new CreateImageEditRequest();
request.setPrompt(aiImage.getPrompt());
request.setN(aiImage.getGenerateNumber());
request.setSize(aiImage.getGenerateSize());
request.setResponseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL);
request.setUser(user.getUuid());
try {
ImageResult imageResult = service.createImageEdit(request, rgbaOriginalImage, maskFile);
log.info("editImage response:{}", imageResult);
return imageResult.getData().stream().map(item -> item.getUrl()).collect(Collectors.toList());
} catch (Exception e) {
log.error("edit image error", e);
}
return Collections.emptyList();
}
public List<String> createImageVariation(User user, AiImage aiImage) {
File imagePath = new File(fileService.getImagePath(aiImage.getOriginalImage()));
OpenAiService service = getOpenAiService();
CreateImageVariationRequest request = new CreateImageVariationRequest();
request.setN(aiImage.getGenerateNumber());
request.setSize(aiImage.getGenerateSize());
request.setResponseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL);
request.setUser(user.getUuid());
try {
ImageResult imageResult = service.createImageVariation(request, imagePath);
log.info("createImageVariation response:{}", imageResult);
return imageResult.getData().stream().map(item -> item.getUrl()).collect(Collectors.toList());
} catch (Exception e) {
log.error("image variation error", e);
}
return Collections.emptyList();
}
}

View File

@ -0,0 +1,86 @@
package com.moyz.adi.common.interfaces;
import com.moyz.adi.common.entity.AiImage;
import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.LocalCache;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import java.net.Proxy;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static com.moyz.adi.common.cosntant.AdiConstant.OPENAI_CREATE_IMAGE_SIZES;
@Slf4j
public abstract class AbstractImageModelService<T> {
protected Proxy proxy;
protected String modelName;
protected T setting;
@Value("${adi.proxy.enable:false}")
protected boolean proxyEnable;
@Value("${adi.proxy.host:0}")
protected String proxyHost;
@Value("${adi.proxy.http-port:0}")
protected int proxyHttpPort;
protected ImageModel imageModel;
public AbstractImageModelService(String modelName, String settingName, Class<T> clazz, Proxy proxy){
this.modelName = modelName;
this.proxy = proxy;
String st = LocalCache.CONFIGS.get(settingName);
setting = JsonUtil.fromJson(st, clazz);
}
public ImageModel getImageModel(User user, String size) {
if (null != imageModel) {
return imageModel;
}
imageModel = buildImageModel(user, size);
return imageModel;
}
/**
* 检测该service是否可用不可用的情况通过是没有配置key
* @return
*/
public abstract boolean isEnabled();
protected abstract ImageModel buildImageModel(User user, String size);
public List<String> createImage(User user, AiImage aiImage) {
if (aiImage.getGenerateNumber() < 1 || aiImage.getGenerateNumber() > 10) {
throw new BaseException(ErrorEnum.A_IMAGE_NUMBER_ERROR);
}
if (!OPENAI_CREATE_IMAGE_SIZES.contains(aiImage.getGenerateSize())) {
throw new BaseException(ErrorEnum.A_IMAGE_SIZE_ERROR);
}
ImageModel imageModel = getImageModel(user, aiImage.getGenerateSize());
try {
Response<List<Image>> response = imageModel.generate(aiImage.getPrompt(), aiImage.getGenerateNumber());
log.info("createImage response:{}", response);
return response.content().stream().map(item -> item.url().toString()).collect(Collectors.toList());
} catch (Exception e) {
log.error("create image error", e);
}
return Collections.emptyList();
}
public abstract List<String> editImage(User user, AiImage aiImage);
public abstract List<String> createImageVariation(User user, AiImage aiImage);
}

View File

@ -0,0 +1,128 @@
package com.moyz.adi.common.interfaces;
import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.vo.AnswerMeta;
import com.moyz.adi.common.vo.ChatMeta;
import com.moyz.adi.common.vo.QuestionMeta;
import com.moyz.adi.common.vo.SseAskParams;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.net.Proxy;
import java.util.UUID;
@Slf4j
public abstract class AbstractLLMService<T> {
protected Proxy proxy;
protected String modelName;
protected T setting;
protected StreamingChatLanguageModel streamingChatLanguageModel;
protected ChatLanguageModel chatLanguageModel;
public AbstractLLMService(String modelName, String settingName, Class<T> clazz, Proxy proxy){
this.modelName = modelName;
this.proxy = proxy;
String st = LocalCache.CONFIGS.get(settingName);
setting = JsonUtil.fromJson(st, clazz);
}
/**
* 检测该service是否可用不可用的情况通常是没有配置key
*
* @return
*/
public abstract boolean isEnabled();
public ChatLanguageModel getChatLLM() {
if (null != chatLanguageModel) {
return chatLanguageModel;
}
chatLanguageModel = buildChatLLM();
return chatLanguageModel;
}
public StreamingChatLanguageModel getStreamingChatLLM() {
if (null != streamingChatLanguageModel) {
return streamingChatLanguageModel;
}
streamingChatLanguageModel = buildStreamingChatLLM();
return streamingChatLanguageModel;
}
protected abstract ChatLanguageModel buildChatLLM();
protected abstract StreamingChatLanguageModel buildStreamingChatLLM();
public String chat(ChatMessage chatMessage) {
return getChatLLM().generate(chatMessage).content().text();
}
public void sseChat(SseAskParams params, TriConsumer<String, QuestionMeta, AnswerMeta> consumer) {
//create chat assistant
AiServices<IChatAssistant> serviceBuilder = AiServices.builder(IChatAssistant.class)
.streamingChatLanguageModel(getStreamingChatLLM());
if (null != params.getChatMemory()) {
serviceBuilder.chatMemory(params.getChatMemory());
}
IChatAssistant chatAssistant = serviceBuilder.build();
TokenStream tokenStream;
if (StringUtils.isNotBlank(params.getSystemMessage())) {
tokenStream = chatAssistant.chat(params.getSystemMessage(), params.getUserMessage());
} else {
tokenStream = chatAssistant.chat(params.getUserMessage());
}
tokenStream.onNext((content) -> {
log.info("get content:{}", content);
//加空格配合前端的fetchEventSource进行解析见https://github.com/Azure/fetch-event-source/blob/45ac3cfffd30b05b79fbf95c21e67d4ef59aa56a/src/parse.ts#L129-L133
try {
params.getSseEmitter().send(" " + content);
} catch (IOException e) {
log.error("stream onNext error", e);
}
})
.onComplete((response) -> {
log.info("返回数据结束了:{}", response);
String questionUuid = StringUtils.isNotBlank(params.getRegenerateQuestionUuid()) ? params.getRegenerateQuestionUuid() : UUID.randomUUID().toString().replace("-", "");
QuestionMeta questionMeta = new QuestionMeta(response.tokenUsage().inputTokenCount(), questionUuid);
AnswerMeta answerMeta = new AnswerMeta(response.tokenUsage().outputTokenCount(), UUID.randomUUID().toString().replace("-", ""));
ChatMeta chatMeta = new ChatMeta(questionMeta, answerMeta);
String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", "");
log.info("meta:" + meta);
try {
params.getSseEmitter().send(" [META]" + meta);
} catch (IOException e) {
log.error("stream onComplete error", e);
throw new RuntimeException(e);
}
// close eventSourceEmitter after tokens was calculated
params.getSseEmitter().complete();
consumer.accept(response.content().text(), questionMeta, answerMeta);
})
.onError((error) -> {
log.error("stream error", error);
try {
params.getSseEmitter().send(SseEmitter.event().name("error").data(error.getMessage()));
} catch (IOException e) {
log.error("sse error", e);
}
params.getSseEmitter().complete();
})
.start();
}
}

View File

@ -0,0 +1,6 @@
package com.moyz.adi.common.interfaces;
@FunctionalInterface
public interface TriConsumer<T, U, V> {
void accept(T t, U u, V v);
}

View File

@ -1,11 +0,0 @@
package com.moyz.adi.common.model;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class AnswerMeta {
private Integer tokens;
private String uuid;
}

View File

@ -1,11 +0,0 @@
package com.moyz.adi.common.model;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class ChatMeta {
private QuestionMeta question;
private AnswerMeta answer;
}

View File

@ -1,14 +0,0 @@
package com.moyz.adi.common.model;
import lombok.Data;
@Data
public class CostStat {
private int day;
private int textRequestTimesByDay;
private int textTokenCostByDay;
private int imageGeneratedNumberByDay;
private int textTokenCostByMonth;
private int textRequestTimesByMonth;
private int imageGeneratedNumberByMonth;
}

View File

@ -1,11 +0,0 @@
package com.moyz.adi.common.model;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class QuestionMeta {
private Integer tokens;
private String uuid;
}

View File

@ -1,13 +0,0 @@
package com.moyz.adi.common.model;
import lombok.Data;
@Data
public class RequestRateLimit {
private int times;
private int minutes;
private int type;
public static final int TYPE_TEXT = 1;
public static final int TYPE_IMAGE = 2;
}

View File

@ -10,7 +10,7 @@ import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.entity.UserDayCost; import com.moyz.adi.common.entity.UserDayCost;
import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.helper.OpenAiHelper; import com.moyz.adi.common.helper.ImageModelContext;
import com.moyz.adi.common.helper.QuotaHelper; import com.moyz.adi.common.helper.QuotaHelper;
import com.moyz.adi.common.helper.RateLimitHelper; import com.moyz.adi.common.helper.RateLimitHelper;
import com.moyz.adi.common.mapper.AiImageMapper; import com.moyz.adi.common.mapper.AiImageMapper;
@ -46,10 +46,6 @@ public class AiImageService extends ServiceImpl<AiImageMapper, AiImage> {
@Resource @Resource
@Lazy @Lazy
private AiImageService _this; private AiImageService _this;
@Resource
private OpenAiHelper openAiHelper;
@Resource @Resource
private QuotaHelper quotaHelper; private QuotaHelper quotaHelper;
@ -173,13 +169,14 @@ public class AiImageService extends ServiceImpl<AiImageMapper, AiImage> {
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId()); String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.IMAGE_RATE_LIMIT_CONFIG); rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.IMAGE_RATE_LIMIT_CONFIG);
ImageModelContext modelContext = new ImageModelContext();
List<String> images = new ArrayList<>(); List<String> images = new ArrayList<>();
if (aiImage.getInteractingMethod() == INTERACTING_METHOD_GENERATE_IMAGE) { if (aiImage.getInteractingMethod() == INTERACTING_METHOD_GENERATE_IMAGE) {
images = openAiHelper.createImage(user, aiImage); images = modelContext.getModelService().createImage(user, aiImage);
} else if (aiImage.getInteractingMethod() == INTERACTING_METHOD_EDIT_IMAGE) { } else if (aiImage.getInteractingMethod() == INTERACTING_METHOD_EDIT_IMAGE) {
images = openAiHelper.editImage(user, aiImage); images = modelContext.getModelService().editImage(user, aiImage);
} else if (aiImage.getInteractingMethod() == INTERACTING_METHOD_VARIATION) { } else if (aiImage.getInteractingMethod() == INTERACTING_METHOD_VARIATION) {
images = openAiHelper.createImageVariation(user, aiImage); images = modelContext.getModelService().createImageVariation(user, aiImage);
} }
List<String> imageUuids = new ArrayList(); List<String> imageUuids = new ArrayList();
images.forEach(imageUrl -> { images.forEach(imageUrl -> {

View File

@ -1,44 +0,0 @@
package com.moyz.adi.common.service;
import com.moyz.adi.common.enums.AiModelStatus;
import com.moyz.adi.common.entity.AiModel;
import com.moyz.adi.common.helper.OpenAiHelper;
import com.moyz.adi.common.mapper.AiModelMapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
@Service
public class AiModelService extends ServiceImpl<AiModelMapper, AiModel> {
public final static List<AiModel> AI_MODELS = new ArrayList<>();
@Resource
private OpenAiHelper openAiHelper;
@PostConstruct
public void init() {
List<AiModel> aiModels = this.lambdaQuery().eq(AiModel::getModelStatus, AiModelStatus.ACTIVE).list();
AI_MODELS.addAll(aiModels);
//get models from openai
// List<Model> openaiModels = openAiHelper.getModels();
// for (Model model : openaiModels) {
// AiModel aiModel = this.lambdaQuery().eq(AiModel::getName, model.getId()).one();
// if (null == aiModel) {
// aiModel = new AiModel();
// aiModel.setName(model.getId());
// aiModel.setModelStatus(AiModelStatus.INACTIVE);
// baseMapper.insert(aiModel);
// }
// }
//refresh models cache
// aiModels = this.lambdaQuery().eq(AiModel::getModelStatus, AiModelStatus.ACTIVE).list();
// AI_MODELS.clear();
// AI_MODELS.addAll(aiModels);
}
}

View File

@ -13,7 +13,7 @@ import com.moyz.adi.common.entity.UserDayCost;
import com.moyz.adi.common.enums.ChatMessageRoleEnum; import com.moyz.adi.common.enums.ChatMessageRoleEnum;
import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.helper.OpenAiHelper; import com.moyz.adi.common.helper.LLMContext;
import com.moyz.adi.common.helper.QuotaHelper; import com.moyz.adi.common.helper.QuotaHelper;
import com.moyz.adi.common.helper.RateLimitHelper; import com.moyz.adi.common.helper.RateLimitHelper;
import com.moyz.adi.common.mapper.ConversationMessageMapper; import com.moyz.adi.common.mapper.ConversationMessageMapper;
@ -59,9 +59,6 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
@Resource @Resource
private StringRedisTemplate stringRedisTemplate; private StringRedisTemplate stringRedisTemplate;
@Resource
private OpenAiHelper openAiHelper;
@Resource @Resource
private QuotaHelper quotaHelper; private QuotaHelper quotaHelper;
@ -225,9 +222,9 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
} }
} }
openAiHelper.sseAsk(sseAskParams, (response, questionMeta, answerMeta) -> { new LLMContext(askReq.getModelName()).getLLMService().sseChat(sseAskParams, (response, questionMeta, answerMeta) -> {
try { try {
_this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta); _this.saveAfterAiResponse(user, askReq, (String) response, (QuestionMeta) questionMeta, (AnswerMeta) answerMeta);
} catch (Exception e) { } catch (Exception e) {
log.error("error:", e); log.error("error:", e);
} finally { } finally {

View File

@ -0,0 +1,55 @@
package com.moyz.adi.common.service;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.vo.DashScopeSetting;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.dashscope.QwenChatModel;
import dev.langchain4j.model.dashscope.QwenStreamingChatModel;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.net.Proxy;
import static com.moyz.adi.common.enums.ErrorEnum.B_LLM_SECRET_KEY_NOT_SET;
/**
* 灵积模型服务(DashScope LLM service)
*/
@Slf4j
public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> {
public DashScopeLLMService(String modelName, Proxy proxy) {
super(modelName, AdiConstant.SysConfigKey.DASHSCOPE_SETTING, DashScopeSetting.class, proxy);
}
@Override
public boolean isEnabled() {
return StringUtils.isNotBlank(setting.getApiKey());
}
@Override
protected StreamingChatLanguageModel buildStreamingChatLLM() {
if (StringUtils.isBlank(setting.getApiKey())) {
throw new BaseException(B_LLM_SECRET_KEY_NOT_SET);
}
return QwenStreamingChatModel.builder()
.apiKey(setting.getApiKey())
.modelName(modelName)
.build();
}
@Override
protected ChatLanguageModel buildChatLLM() {
if (StringUtils.isBlank(setting.getApiKey())) {
throw new BaseException(B_LLM_SECRET_KEY_NOT_SET);
}
return QwenChatModel.builder()
.apiKey(setting.getApiKey())
.modelName(modelName)
.build();
}
}

View File

@ -1,22 +1,48 @@
package com.moyz.adi.common.service; package com.moyz.adi.common.service;
import com.moyz.adi.common.helper.EmbeddingHelper; import com.moyz.adi.common.helper.ImageModelContext;
import com.moyz.adi.common.helper.LLMContext;
import dev.langchain4j.model.dashscope.QwenModelName;
import dev.langchain4j.model.openai.OpenAiModelName;
import jakarta.annotation.PostConstruct; import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.net.InetSocketAddress;
import java.net.Proxy;
@Service @Service
public class Initializer { public class Initializer {
@Value("${adi.proxy.enable:false}")
protected boolean proxyEnable;
@Value("${adi.proxy.host:0}")
protected String proxyHost;
@Value("${adi.proxy.http-port:0}")
protected int proxyHttpPort;
@Resource @Resource
private SysConfigService sysConfigService; private SysConfigService sysConfigService;
@Resource @Resource
private EmbeddingHelper embeddingHelper; private RAGService ragService;
@PostConstruct @PostConstruct
public void init(){ public void init() {
sysConfigService.reload(); sysConfigService.reload();
embeddingHelper.init();
Proxy proxy = null;
if (proxyEnable) {
proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort));
}
LLMContext.addLLMService(OpenAiModelName.GPT_3_5_TURBO, new OpenAiLLMService(OpenAiModelName.GPT_3_5_TURBO, proxy));
LLMContext.addLLMService(QwenModelName.QWEN_MAX, new DashScopeLLMService(QwenModelName.QWEN_MAX, proxy));
ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2, proxy));
ragService.init();
} }
} }

View File

@ -23,7 +23,13 @@ public class KnowledgeBaseEmbeddingService extends ServiceImpl<KnowledgeBaseEmbe
return result; return result;
} }
public boolean deleteByItemUuid(String kbItemUuid){ /**
* 删除{kbItemUuid}这个知识库条目的向量
*
* @param kbItemUuid 知识库条目uuid
* @return
*/
public boolean deleteByItemUuid(String kbItemUuid) {
return baseMapper.deleteByItemUuid(kbItemUuid); return baseMapper.deleteByItemUuid(kbItemUuid);
} }
} }

View File

@ -10,7 +10,6 @@ import com.moyz.adi.common.entity.KnowledgeBase;
import com.moyz.adi.common.entity.KnowledgeBaseItem; import com.moyz.adi.common.entity.KnowledgeBaseItem;
import com.moyz.adi.common.entity.User; import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.helper.EmbeddingHelper;
import com.moyz.adi.common.mapper.KnowledgeBaseItemMapper; import com.moyz.adi.common.mapper.KnowledgeBaseItemMapper;
import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.document.Metadata;
@ -32,7 +31,7 @@ import static com.moyz.adi.common.enums.ErrorEnum.*;
public class KnowledgeBaseItemService extends ServiceImpl<KnowledgeBaseItemMapper, KnowledgeBaseItem> { public class KnowledgeBaseItemService extends ServiceImpl<KnowledgeBaseItemMapper, KnowledgeBaseItem> {
@Resource @Resource
private EmbeddingHelper embeddingHelper; private RAGService ragService;
@Resource @Resource
private KnowledgeBaseEmbeddingService knowledgeBaseEmbeddingService; private KnowledgeBaseEmbeddingService knowledgeBaseEmbeddingService;
@ -103,12 +102,20 @@ public class KnowledgeBaseItemService extends ServiceImpl<KnowledgeBaseItemMappe
} }
public boolean embedding(KnowledgeBaseItem one) { /**
* 知识点向量化如向量已存在则先删除
*
* @param kbItem
* @return
*/
public boolean embedding(KnowledgeBaseItem kbItem) {
knowledgeBaseEmbeddingService.deleteByItemUuid(kbItem.getUuid());
Metadata metadata = new Metadata(); Metadata metadata = new Metadata();
metadata.add("kb_uuid", one.getKbUuid()); metadata.add("kb_uuid", kbItem.getKbUuid());
metadata.add("kb_item_uuid", one.getUuid()); metadata.add("kb_item_uuid", kbItem.getUuid());
Document document = new Document(one.getRemark(), metadata); Document document = new Document(kbItem.getRemark(), metadata);
embeddingHelper.getEmbeddingStoreIngestor().ingest(document); ragService.ingest(document);
return true; return true;
} }

View File

@ -9,7 +9,6 @@ import com.moyz.adi.common.cosntant.RedisKeyConstant;
import com.moyz.adi.common.dto.KbEditReq; import com.moyz.adi.common.dto.KbEditReq;
import com.moyz.adi.common.entity.*; import com.moyz.adi.common.entity.*;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.helper.EmbeddingHelper;
import com.moyz.adi.common.mapper.KnowledgeBaseMapper; import com.moyz.adi.common.mapper.KnowledgeBaseMapper;
import com.moyz.adi.common.util.BizPager; import com.moyz.adi.common.util.BizPager;
import com.moyz.adi.common.util.LocalDateTimeUtil; import com.moyz.adi.common.util.LocalDateTimeUtil;
@ -44,7 +43,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
private StringRedisTemplate stringRedisTemplate; private StringRedisTemplate stringRedisTemplate;
@Resource @Resource
private EmbeddingHelper embeddingHelper; private RAGService ragService;
@Resource @Resource
private KnowledgeBaseItemService knowledgeBaseItemService; private KnowledgeBaseItemService knowledgeBaseItemService;
@ -148,7 +147,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
.add("kb_uuid", knowledgeBase.getUuid()) .add("kb_uuid", knowledgeBase.getUuid())
.add("kb_item_uuid", knowledgeBaseItem.getUuid()); .add("kb_item_uuid", knowledgeBaseItem.getUuid());
embeddingHelper.getEmbeddingStoreIngestor().ingest(docWithoutPath); ragService.ingest(docWithoutPath);
knowledgeBaseItemService knowledgeBaseItemService
.lambdaUpdate() .lambdaUpdate()
@ -194,7 +193,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
.update(); .update();
} }
public KnowledgeBaseQaRecord answerAndRecord(String kbUuid, String question) { public KnowledgeBaseQaRecord answerAndRecord(String kbUuid, String question, String modelName) {
String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd")); String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd"));
String askTimes = stringRedisTemplate.opsForValue().get(key); String askTimes = stringRedisTemplate.opsForValue().get(key);
@ -205,7 +204,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
stringRedisTemplate.opsForValue().increment(key); stringRedisTemplate.opsForValue().increment(key);
KnowledgeBase knowledgeBase = getOrThrow(kbUuid); KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
String answer = embeddingHelper.findAnswer(kbUuid, question); String answer = ragService.findAnswer(kbUuid, question, modelName);
String uuid = UUID.randomUUID().toString().replace("-", ""); String uuid = UUID.randomUUID().toString().replace("-", "");
KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord(); KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord();
newObj.setKbId(knowledgeBase.getId()); newObj.setKbId(knowledgeBase.getId());

View File

@ -0,0 +1,137 @@
package com.moyz.adi.common.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.entity.AiImage;
import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.interfaces.AbstractImageModelService;
import com.moyz.adi.common.util.ImageUtil;
import com.moyz.adi.common.vo.OpenAiSetting;
import com.theokanning.openai.OpenAiApi;
import com.theokanning.openai.image.CreateImageEditRequest;
import com.theokanning.openai.image.CreateImageVariationRequest;
import com.theokanning.openai.image.ImageResult;
import com.theokanning.openai.service.OpenAiService;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.openai.OpenAiImageModel;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import org.apache.commons.lang3.StringUtils;
import retrofit2.Retrofit;
import java.io.File;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static com.moyz.adi.common.cosntant.AdiConstant.OPENAI_CREATE_IMAGE_RESP_FORMATS_URL;
import static com.theokanning.openai.service.OpenAiService.defaultClient;
import static com.theokanning.openai.service.OpenAiService.defaultRetrofit;
import static dev.ai4j.openai4j.image.ImageModel.DALL_E_SIZE_512_x_512;
@Slf4j
public class OpenAiImageModelService extends AbstractImageModelService<OpenAiSetting> {
@Resource
private FileService fileService;
@Resource
private ObjectMapper objectMapper;
public OpenAiImageModelService(String modelName, Proxy proxy) {
super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class, proxy);
}
@Override
public boolean isEnabled() {
return StringUtils.isNotBlank(setting.getSecretKey());
}
@Override
public ImageModel buildImageModel(User user, String size) {
if (StringUtils.isBlank(setting.getSecretKey())) {
throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET);
}
OpenAiImageModel.OpenAiImageModelBuilder builder = OpenAiImageModel.builder()
.modelName(modelName)
.apiKey(setting.getSecretKey())
.user(user.getUuid())
.responseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL)
.size(StringUtils.defaultString(size, DALL_E_SIZE_512_x_512))
.logRequests(true)
.logResponses(true)
.withPersisting(false)
.maxRetries(2);
if (null != proxy) {
builder.proxy(proxy);
}
return builder.build();
}
@Override
public List<String> editImage(User user, AiImage aiImage) {
File originalFile = new File(fileService.getImagePath(aiImage.getOriginalImage()));
File maskFile = null;
if (StringUtils.isNotBlank(aiImage.getMaskImage())) {
maskFile = new File(fileService.getImagePath(aiImage.getMaskImage()));
}
//如果不是RGBA类型的图片先转成RGBA
File rgbaOriginalImage = ImageUtil.rgbConvertToRgba(originalFile, fileService.getTmpImagesPath(aiImage.getOriginalImage()));
OpenAiService service = getOpenAiService();
CreateImageEditRequest request = new CreateImageEditRequest();
request.setPrompt(aiImage.getPrompt());
request.setN(aiImage.getGenerateNumber());
request.setSize(aiImage.getGenerateSize());
request.setResponseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL);
request.setUser(user.getUuid());
try {
ImageResult imageResult = service.createImageEdit(request, rgbaOriginalImage, maskFile);
log.info("editImage response:{}", imageResult);
return imageResult.getData().stream().map(item -> item.getUrl()).collect(Collectors.toList());
} catch (Exception e) {
log.error("edit image error", e);
}
return Collections.emptyList();
}
@Override
public List<String> createImageVariation(User user, AiImage aiImage) {
File imagePath = new File(fileService.getImagePath(aiImage.getOriginalImage()));
OpenAiService service = getOpenAiService();
CreateImageVariationRequest request = new CreateImageVariationRequest();
request.setN(aiImage.getGenerateNumber());
request.setSize(aiImage.getGenerateSize());
request.setResponseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL);
request.setUser(user.getUuid());
try {
ImageResult imageResult = service.createImageVariation(request, imagePath);
log.info("createImageVariation response:{}", imageResult);
return imageResult.getData().stream().map(item -> item.getUrl()).collect(Collectors.toList());
} catch (Exception e) {
log.error("image variation error", e);
}
return Collections.emptyList();
}
public OpenAiService getOpenAiService() {
String secretKey = setting.getSecretKey();
if (proxyEnable) {
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort));
OkHttpClient client = defaultClient(secretKey, Duration.of(60, ChronoUnit.SECONDS))
.newBuilder()
.proxy(proxy)
.build();
Retrofit retrofit = defaultRetrofit(client, objectMapper);
OpenAiApi api = retrofit.create(OpenAiApi.class);
return new OpenAiService(api);
}
return new OpenAiService(secretKey, Duration.of(60, ChronoUnit.SECONDS));
}
}

View File

@ -0,0 +1,65 @@
package com.moyz.adi.common.service;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.interfaces.AbstractLLMService;
import com.moyz.adi.common.vo.OpenAiSetting;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.net.Proxy;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
/**
* OpenAi LLM service
*/
@Slf4j
@Accessors(chain = true)
public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
public OpenAiLLMService(String modelName, Proxy proxy) {
super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class, proxy);
}
@Override
public boolean isEnabled() {
return StringUtils.isNotBlank(setting.getSecretKey());
}
@Override
protected StreamingChatLanguageModel buildStreamingChatLLM() {
if (StringUtils.isBlank(setting.getSecretKey())) {
throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET);
}
OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder builder = OpenAiStreamingChatModel
.builder()
.modelName(modelName)
.apiKey(setting.getSecretKey())
.timeout(Duration.of(60, ChronoUnit.SECONDS));
if (null != proxy) {
builder.proxy(proxy);
}
return builder.build();
}
@Override
protected ChatLanguageModel buildChatLLM() {
if (StringUtils.isBlank(setting.getSecretKey())) {
throw new BaseException(ErrorEnum.B_LLM_SECRET_KEY_NOT_SET);
}
OpenAiChatModel.OpenAiChatModelBuilder builder = OpenAiChatModel.builder().apiKey(setting.getSecretKey());
if (null != proxy) {
builder.proxy(proxy);
}
return builder.build();
}
}

View File

@ -1,30 +1,25 @@
package com.moyz.adi.common.helper; package com.moyz.adi.common.service;
import com.moyz.adi.common.helper.LLMContext;
import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore; import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter; import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.splitter.DocumentSplitters; import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel; import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiTokenizer; import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor; import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.regex.Matcher; import java.util.regex.Matcher;
@ -35,8 +30,7 @@ import static java.util.stream.Collectors.joining;
@Slf4j @Slf4j
@Service @Service
public class EmbeddingHelper { public class RAGService {
@Value("${spring.datasource.url}") @Value("${spring.datasource.url}")
private String dataBaseUrl; private String dataBaseUrl;
@ -45,32 +39,16 @@ public class EmbeddingHelper {
@Value("${spring.datasource.password}") @Value("${spring.datasource.password}")
private String dataBasePassword; private String dataBasePassword;
@Value("${openai.proxy.enable:false}")
private boolean proxyEnable;
@Value("${openai.proxy.host:0}")
private String proxyHost;
@Value("${openai.proxy.http-port:0}")
private int proxyHttpPort;
private static final PromptTemplate promptTemplate = PromptTemplate.from("尽可能准确地回答下面的问题: {{question}}\n\n根据以下知识库的内容:\n{{information}}"); private static final PromptTemplate promptTemplate = PromptTemplate.from("尽可能准确地回答下面的问题: {{question}}\n\n根据以下知识库的内容:\n{{information}}");
@Resource
private OpenAiHelper openAiHelper;
private EmbeddingModel embeddingModel; private EmbeddingModel embeddingModel;
private EmbeddingStore<TextSegment> embeddingStore; private EmbeddingStore<TextSegment> embeddingStore;
private ChatLanguageModel chatLanguageModel;
public void init() { public void init() {
log.info("initEmbeddingModel"); log.info("initEmbeddingModel");
embeddingModel = new AllMiniLmL6V2EmbeddingModel(); embeddingModel = new AllMiniLmL6V2EmbeddingModel();
embeddingStore = initEmbeddingStore(); embeddingStore = initEmbeddingStore();
chatLanguageModel = initChatLanguageModel();
} }
private EmbeddingStore<TextSegment> initEmbeddingStore() { private EmbeddingStore<TextSegment> initEmbeddingStore() {
@ -107,16 +85,7 @@ public class EmbeddingHelper {
return embeddingStore; return embeddingStore;
} }
private ChatLanguageModel initChatLanguageModel() { private EmbeddingStoreIngestor getEmbeddingStoreIngestor() {
OpenAiChatModel.OpenAiChatModelBuilder builder = OpenAiChatModel.builder().apiKey(openAiHelper.getSecretKey());
if (proxyEnable) {
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort));
builder.proxy(proxy);
}
return builder.build();
}
public EmbeddingStoreIngestor getEmbeddingStoreIngestor() {
DocumentSplitter documentSplitter = DocumentSplitters.recursive(1000, 0, new OpenAiTokenizer(GPT_3_5_TURBO)); DocumentSplitter documentSplitter = DocumentSplitters.recursive(1000, 0, new OpenAiTokenizer(GPT_3_5_TURBO));
EmbeddingStoreIngestor embeddingStoreIngestor = EmbeddingStoreIngestor.builder() EmbeddingStoreIngestor embeddingStoreIngestor = EmbeddingStoreIngestor.builder()
.documentSplitter(documentSplitter) .documentSplitter(documentSplitter)
@ -126,7 +95,24 @@ public class EmbeddingHelper {
return embeddingStoreIngestor; return embeddingStoreIngestor;
} }
public String findAnswer(String kbUuid, String question) { /**
* 对文档切块并向量化
*
* @param document 知识库文档
*/
public void ingest(Document document) {
getEmbeddingStoreIngestor().ingest(document);
}
/**
* 召回并搜索
*
* @param kbUuid 知识库uuid
* @param question 用户的问题
* @param modelName LLM model name
* @return
*/
public String findAnswer(String kbUuid, String question, String modelName) {
// Embed the question // Embed the question
Embedding questionEmbedding = embeddingModel.embed(question).content(); Embedding questionEmbedding = embeddingModel.embed(question).content();
@ -147,10 +133,6 @@ public class EmbeddingHelper {
} }
Prompt prompt = promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information))); Prompt prompt = promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information)));
AiMessage aiMessage = chatLanguageModel.generate(prompt.toUserMessage()).content(); return new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage());
// See an answer from the model
return aiMessage.text();
} }
} }

View File

@ -57,10 +57,6 @@ public class SysConfigService extends ServiceImpl<SysConfigMapper, SysConfig> {
return Integer.parseInt(maxNum); return Integer.parseInt(maxNum);
} }
public static String getSecretKey() {
return LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.SECRET_KEY);
}
public static String getByKey(String key) { public static String getByKey(String key) {
return LocalCache.CONFIGS.get(key); return LocalCache.CONFIGS.get(key);
} }

View File

@ -98,7 +98,7 @@ public class UserService extends ServiceImpl<UserMapper, User> {
.eq(User::getEmail, email) .eq(User::getEmail, email)
.one(); .one();
if (null != user && user.getUserStatus() == UserStatusEnum.NORMAL) { if (null != user && user.getUserStatus() == UserStatusEnum.NORMAL) {
throw new BaseException(A_REGISTER_USER_EXIST); throw new BaseException(A_USER_EXIST);
} }
if (null != user) { if (null != user) {
sendActiveEmail(email); sendActiveEmail(email);
@ -112,7 +112,7 @@ public class UserService extends ServiceImpl<UserMapper, User> {
//创建用户 //创建用户
User newOne = new User(); User newOne = new User();
newOne.setName(StringUtils.substringBetween(email, "@")); newOne.setName(StringUtils.substringBefore(email, "@"));
newOne.setUuid(UUID.randomUUID().toString().replace("-", "")); newOne.setUuid(UUID.randomUUID().toString().replace("-", ""));
newOne.setEmail(email); newOne.setEmail(email);
newOne.setPassword(hashed); newOne.setPassword(hashed);

View File

@ -0,0 +1,11 @@
package com.moyz.adi.common.vo;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
@Data
public class DashScopeSetting {
@JsonProperty("api_key")
private String apiKey;
}

View File

@ -0,0 +1,12 @@
package com.moyz.adi.common.vo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.moyz.adi.common.interfaces.AbstractImageModelService;
import lombok.Data;
@Data
public class ImageModelInfo extends ModelInfo {
@JsonIgnore
private AbstractImageModelService modelService;
}

View File

@ -0,0 +1,12 @@
package com.moyz.adi.common.vo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.moyz.adi.common.interfaces.AbstractLLMService;
import lombok.Data;
@Data
public class LLMModelInfo extends ModelInfo {
@JsonIgnore
private AbstractLLMService llmService;
}

View File

@ -0,0 +1,9 @@
package com.moyz.adi.common.vo;
import lombok.Data;
@Data
public class ModelInfo {
private String modelName;
private Boolean enable;
}

View File

@ -0,0 +1,11 @@
package com.moyz.adi.common.vo;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
@Data
public class OpenAiSetting {
@JsonProperty("secret_key")
private String secretKey;
}

View File

@ -1,7 +1,6 @@
package com.moyz.adi.common.vo; package com.moyz.adi.common.vo;
import com.moyz.adi.common.entity.User; import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.util.TriConsumer;
import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.ChatMemory;
import lombok.Data; import lombok.Data;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

View File

@ -3,10 +3,23 @@
<mapper namespace="com.moyz.adi.common.mapper.KnowledgeBaseEmbeddingMapper"> <mapper namespace="com.moyz.adi.common.mapper.KnowledgeBaseEmbeddingMapper">
<select id="selectByItemUuid" resultType="com.moyz.adi.common.entity.KnowledgeBaseEmbedding"> <select id="selectByItemUuid" resultType="com.moyz.adi.common.entity.KnowledgeBaseEmbedding">
select * from adi_knowledge_base_embedding where metadata->>'kb_item_uuid' = #{kbItemUuid} select *
from adi_knowledge_base_embedding
where metadata ->> 'kb_item_uuid' = #{kbItemUuid}
</select> </select>
<delete id="deleteByIds">
delete from adi_knowledge_base_embedding where embedding_id in
<foreach collection="ids" open="(" separator="," close=")" item="id">
<if test="id != null and id != ''">
#{id}
</if>
</foreach>
</delete>
<delete id="deleteByItemUuid"> <delete id="deleteByItemUuid">
delete from adi_knowledge_base_embedding where metadata->>'kb_item_uuid' = #{kbItemUuid} delete
from adi_knowledge_base_embedding
where metadata ->> 'kb_item_uuid' = #{kbItemUuid}
</delete> </delete>
</mapper> </mapper>

View File

@ -201,7 +201,7 @@ CREATE TABLE public.adi_sys_config
( (
id bigserial primary key, id bigserial primary key,
name character varying(100) DEFAULT ''::character varying NOT NULL, name character varying(100) DEFAULT ''::character varying NOT NULL,
value character varying(100) DEFAULT ''::character varying NOT NULL, value character varying(1000) DEFAULT ''::character varying NOT NULL,
create_time timestamp DEFAULT localtimestamp NOT NULL, create_time timestamp DEFAULT localtimestamp NOT NULL,
update_time timestamp DEFAULT localtimestamp NOT NULL, update_time timestamp DEFAULT localtimestamp NOT NULL,
is_deleted boolean DEFAULT false NOT NULL is_deleted boolean DEFAULT false NOT NULL
@ -382,7 +382,9 @@ CREATE TRIGGER trigger_user_day_cost_update_time
EXECUTE PROCEDURE update_modified_column(); EXECUTE PROCEDURE update_modified_column();
INSERT INTO adi_sys_config (name, value) INSERT INTO adi_sys_config (name, value)
VALUES ('secret_key', ''); VALUES ('openai_setting', '{"secret_key":""}');
INSERT INTO adi_sys_config (name, value)
VALUES ('dashscope_setting', '{"api_key":""}');
INSERT INTO adi_sys_config (name, value) INSERT INTO adi_sys_config (name, value)
VALUES ('request_text_rate_limit', '{"times":24,"minutes":3}'); VALUES ('request_text_rate_limit', '{"times":24,"minutes":3}');
INSERT INTO adi_sys_config (name, value) INSERT INTO adi_sys_config (name, value)

View File

@ -25,7 +25,7 @@
<maven.compiler.source>17</maven.compiler.source> <maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target> <maven.compiler.target>17</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<langchain4j.version>0.25.0</langchain4j.version> <langchain4j.version>0.27.1</langchain4j.version>
</properties> </properties>
<dependencies> <dependencies>
<dependency> <dependency>
@ -160,6 +160,11 @@
<artifactId>langchain4j-document-parser-apache-poi</artifactId> <artifactId>langchain4j-document-parser-apache-poi</artifactId>
<version>${langchain4j.version}</version> <version>${langchain4j.version}</version>
</dependency> </dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-dashscope</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.springframework.boot</groupId> <groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId> <artifactId>spring-boot-starter-test</artifactId>