From 86a70e09caeb546ea816b20eff80b7db756f768c Mon Sep 17 00:00:00 2001 From: moyangzhan Date: Mon, 8 Apr 2024 00:17:23 +0800 Subject: [PATCH] add ai search --- README.md | 36 ++- .../src/main/resources/application.yml | 3 + .../adi/chat/controller/AuthController.java | 9 + .../controller/KnowledgeBaseController.java | 2 +- .../controller/KnowledgeBaseQAController.java | 2 +- .../adi/chat/controller/SearchController.java | 29 ++ .../controller/SearchRecordController.java | 28 ++ .../adi/common/base/JsonNodeTypeHandler.java | 73 +++++ .../base/SearchEngineRespTypeHandler.java | 70 +++++ .../moyz/adi/common/config/BeanConfig.java | 34 ++- .../moyz/adi/common/cosntant/AdiConstant.java | 40 ++- .../adi/common/dto/AiSearchRecordResp.java | 33 +++ .../com/moyz/adi/common/dto/AiSearchReq.java | 19 ++ .../com/moyz/adi/common/dto/AiSearchResp.java | 13 + .../adi/common/dto/GoogleSearchError.java | 9 + .../moyz/adi/common/dto/GoogleSearchResp.java | 47 ++++ .../moyz/adi/common/dto/SearchEngineResp.java | 12 + .../com/moyz/adi/common/dto/SearchResult.java | 11 + .../moyz/adi/common/dto/SearchResultItem.java | 11 + .../adi/common/entity/AiSearchRecord.java | 50 ++++ .../com/moyz/adi/common/enums/ErrorEnum.java | 2 +- .../moyz/adi/common/helper/QuotaHelper.java | 8 +- .../adi/common/helper/SSEEmitterHelper.java | 52 ++-- .../interfaces/AbstractImageModelService.java | 9 +- .../common/interfaces/AbstractLLMService.java | 14 +- .../interfaces/AbstractSearchEngine.java | 45 ++++ .../interfaces/AbstractSearchService.java | 11 + .../common/mapper/AiSearchRecordMapper.java | 9 + .../searchengine/GoogleSearchEngine.java | 56 ++++ .../searchengine/SearchEngineContext.java | 40 +++ .../common/service/AiSearchRecordService.java | 82 ++++++ .../service/ConversationMessageService.java | 20 +- .../common/service/DashScopeLLMService.java | 2 +- .../moyz/adi/common/service/Initializer.java | 18 +- .../service/KnowledgeBaseItemService.java | 6 +- .../service/KnowledgeBaseQaRecordService.java | 2 +- .../common/service/KnowledgeBaseService.java | 32 ++- .../adi/common/service/OllamaLLMService.java | 3 +- .../service/OpenAiImageModelService.java | 4 +- .../adi/common/service/OpenAiLLMService.java | 5 +- .../adi/common/service/QianFanLLMService.java | 2 +- .../moyz/adi/common/service/RAGService.java | 52 ++-- .../adi/common/service/SearchService.java | 250 ++++++++++++++++++ .../util/AdiPgVectorEmbeddingStore.java | 19 +- .../com/moyz/adi/common/util/BizPager.java | 16 ++ .../com/moyz/adi/common/util/SpringUtil.java | 25 ++ .../com/moyz/adi/common/vo/GoogleSetting.java | 10 + .../moyz/adi/common/vo/SearchEngineInfo.java | 13 + docs/create.sql | 48 ++++ 49 files changed, 1280 insertions(+), 106 deletions(-) create mode 100644 adi-chat/src/main/java/com/moyz/adi/chat/controller/SearchController.java create mode 100644 adi-chat/src/main/java/com/moyz/adi/chat/controller/SearchRecordController.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/base/JsonNodeTypeHandler.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/base/SearchEngineRespTypeHandler.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/dto/AiSearchRecordResp.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/dto/AiSearchReq.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/dto/AiSearchResp.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/dto/GoogleSearchError.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/dto/GoogleSearchResp.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/dto/SearchEngineResp.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/dto/SearchResult.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/dto/SearchResultItem.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/entity/AiSearchRecord.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractSearchEngine.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractSearchService.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/mapper/AiSearchRecordMapper.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/searchengine/GoogleSearchEngine.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/searchengine/SearchEngineContext.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/service/AiSearchRecordService.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/service/SearchService.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/util/SpringUtil.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/vo/GoogleSetting.java create mode 100644 adi-common/src/main/java/com/moyz/adi/common/vo/SearchEngineInfo.java diff --git a/README.md b/README.md index 45b0cd4..4b360f3 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,9 @@ * 提示词 * 额度控制 * 基于大模型的知识库(RAG) +* 基于大模型的搜索(RAG) * 多模型随意切换 +* 多搜索引擎随意切换 ## 接入的模型: @@ -27,6 +29,14 @@ * ollama * DALL-E 2 +## 接入的搜索引擎 + +Google + +Bing (TODO) + +百度 (TODO) + ## 技术栈 该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web) @@ -53,7 +63,7 @@ vue3+typescript+pnpm * 创建数据库aideepin * 执行docs/create.sql -* 填充各模型的配置 +* 填充各模型的配置(至少设置一个) openai的secretKey @@ -79,6 +89,15 @@ ollama的配置 update adi_sys_config set value = '{"base_url":"my_ollama_base_url","models":["my model name,eg:tinydolphin"]}' where name = 'ollama_setting'; ``` +* 填充搜索引擎的配置 + +Google的配置 + +``` +update adi_sys_config set value = '{"url":"https://www.googleapis.com/customsearch/v1","key":"my key from cloud.google.com","cx":"my cx from programmablesearchengine.google.com"}' where name = 'google_setting'; +``` + + **b. 修改配置文件** * postgresql: application-[dev|prod].xml中的spring.datasource @@ -122,23 +141,30 @@ docker run -d \ ## 待办: -* AI搜索 -* 增强RAG +增强RAG + +增加搜索引擎(BING、百度) ## 截图 **AI聊天:** ![1691583184761](image/README/1691583184761.png) -![1691583124744](image/README/1691583124744.png "AI绘图") +**AI画图:** -![1691583329105](image/README/1691583329105.png "token统计") +![1691583124744](image/README/1691583124744.png "AI绘图") **知识库:** ![kbindex](image/README/kbidx.png) ![kb01](image/README/kb01.png) +**向量化:** + ![kb02](image/README/kb02.png) ![kb03](image/README/kb03.png) + +**额度统计:** + +![1691583329105](https://file+.vscode-resource.vscode-cdn.net/e%3A/WORKSPACE/aideepin/image/README/1691583329105.png "token统计") diff --git a/adi-bootstrap/src/main/resources/application.yml b/adi-bootstrap/src/main/resources/application.yml index 4174f9f..ac9f8e7 100644 --- a/adi-bootstrap/src/main/resources/application.yml +++ b/adi-bootstrap/src/main/resources/application.yml @@ -11,6 +11,9 @@ spring: name: AiDeepIn profiles: active: dev + mvc: + async: + request-timeout: 60000 jackson: date-format: "yyyy-MM-dd HH:mm:ss" time-zone: "GMT+8" diff --git a/adi-chat/src/main/java/com/moyz/adi/chat/controller/AuthController.java b/adi-chat/src/main/java/com/moyz/adi/chat/controller/AuthController.java index 7e5b2e6..fe68602 100644 --- a/adi-chat/src/main/java/com/moyz/adi/chat/controller/AuthController.java +++ b/adi-chat/src/main/java/com/moyz/adi/chat/controller/AuthController.java @@ -3,7 +3,9 @@ package com.moyz.adi.chat.controller; import com.moyz.adi.common.dto.LoginReq; import com.moyz.adi.common.dto.LoginResp; import com.moyz.adi.common.dto.RegisterReq; +import com.moyz.adi.common.searchengine.SearchEngineContext; import com.moyz.adi.common.service.UserService; +import com.moyz.adi.common.vo.SearchEngineInfo; import com.ramostear.captcha.HappyCaptcha; import com.ramostear.captcha.support.CaptchaType; import io.swagger.v3.oas.annotations.Operation; @@ -23,6 +25,8 @@ import org.springframework.web.bind.annotation.*; import java.io.IOException; import java.net.URLEncoder; +import java.util.List; +import java.util.stream.Collectors; import static org.springframework.http.HttpHeaders.AUTHORIZATION; @@ -123,4 +127,9 @@ public class AuthController { happyCaptcha.output(); } + @Operation(summary = "Search engine list") + @GetMapping(value = "/search-engine/list") + public List engines() { + return SearchEngineContext.NAME_TO_ENGINE.values().stream().collect(Collectors.toList()); + } } diff --git a/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseController.java b/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseController.java index c251479..6e0dfcf 100644 --- a/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseController.java +++ b/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseController.java @@ -70,7 +70,7 @@ public class KnowledgeBaseController { * * @return */ - @PostMapping("/star/{uuid}") + @PostMapping("/star/{kbUuid}") public boolean star(@PathVariable String kbUuid) { return knowledgeBaseService.star(kbUuid); } diff --git a/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseQAController.java b/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseQAController.java index d15fcc7..40ad4c4 100644 --- a/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseQAController.java +++ b/adi-chat/src/main/java/com/moyz/adi/chat/controller/KnowledgeBaseQAController.java @@ -44,6 +44,6 @@ public class KnowledgeBaseQAController { @PostMapping("/record/del/{uuid}") public boolean recordDel(@PathVariable String uuid) { - return knowledgeBaseQaRecordService.softDelele(uuid); + return knowledgeBaseQaRecordService.softDelete(uuid); } } diff --git a/adi-chat/src/main/java/com/moyz/adi/chat/controller/SearchController.java b/adi-chat/src/main/java/com/moyz/adi/chat/controller/SearchController.java new file mode 100644 index 0000000..833180f --- /dev/null +++ b/adi-chat/src/main/java/com/moyz/adi/chat/controller/SearchController.java @@ -0,0 +1,29 @@ +package com.moyz.adi.chat.controller; + +import com.moyz.adi.common.dto.AiSearchReq; +import com.moyz.adi.common.service.SearchService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.annotation.Resource; +import org.springframework.http.MediaType; +import org.springframework.validation.annotation.Validated; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +@Tag(name = "AI search controller") +@RequestMapping("/ai-search/") +@RestController +public class SearchController { + + @Resource + private SearchService searchService; + + @Operation(summary = "sse process") + @PostMapping(value = "/process", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public SseEmitter sseAsk(@RequestBody @Validated AiSearchReq req) { + return searchService.search(req.isBriefSearch(), req.getSearchText(), req.getEngineName(), req.getModelName()); + } +} diff --git a/adi-chat/src/main/java/com/moyz/adi/chat/controller/SearchRecordController.java b/adi-chat/src/main/java/com/moyz/adi/chat/controller/SearchRecordController.java new file mode 100644 index 0000000..a083402 --- /dev/null +++ b/adi-chat/src/main/java/com/moyz/adi/chat/controller/SearchRecordController.java @@ -0,0 +1,28 @@ +package com.moyz.adi.chat.controller; + +import com.moyz.adi.common.dto.AiSearchResp; +import com.moyz.adi.common.service.AiSearchRecordService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.annotation.Resource; +import org.springframework.web.bind.annotation.*; + +@Tag(name = "Ai search record controller") +@RequestMapping("/ai-search-record/") +@RestController +public class SearchRecordController { + + @Resource + private AiSearchRecordService aiSearchRecordService; + + @Operation(summary = "List by max id") + @GetMapping(value = "/list") + public AiSearchResp list(@RequestParam(defaultValue = "0") Long maxId, String keyword) { + return aiSearchRecordService.listByMaxId(maxId, keyword); + } + + @PostMapping("/del/{uuid}") + public boolean recordDel(@PathVariable String uuid) { + return aiSearchRecordService.softDelete(uuid); + } +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/base/JsonNodeTypeHandler.java b/adi-common/src/main/java/com/moyz/adi/common/base/JsonNodeTypeHandler.java new file mode 100644 index 0000000..b383dfa --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/base/JsonNodeTypeHandler.java @@ -0,0 +1,73 @@ +package com.moyz.adi.common.base; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.ibatis.type.BaseTypeHandler; +import org.apache.ibatis.type.JdbcType; +import org.apache.ibatis.type.MappedJdbcTypes; +import org.apache.ibatis.type.MappedTypes; +import org.postgresql.util.PGobject; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +@MappedJdbcTypes({JdbcType.JAVA_OBJECT}) +@MappedTypes({JsonNode.class}) +public class JsonNodeTypeHandler extends BaseTypeHandler { + + private static final ObjectMapper objectMapper = new ObjectMapper(); + + @Override + public void setNonNullParameter(PreparedStatement ps, int i, JsonNode parameter, JdbcType jdbcType) + throws SQLException { + PGobject jsonObject = new PGobject(); + jsonObject.setType("jsonb"); + try { + jsonObject.setValue(parameter.toString()); + ps.setObject(i, jsonObject); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public JsonNode getNullableResult(ResultSet rs, String columnName) throws SQLException { + String jsonSource = rs.getString(columnName); + if (jsonSource != null) { + try { + return objectMapper.readTree(jsonSource); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return null; + } + + @Override + public JsonNode getNullableResult(ResultSet rs, int columnIndex) throws SQLException { + String jsonSource = rs.getString(columnIndex); + if (jsonSource != null) { + try { + return objectMapper.readTree(jsonSource); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return null; + } + + @Override + public JsonNode getNullableResult(CallableStatement cs, int columnIndex) throws SQLException { + String jsonSource = cs.getString(columnIndex); + if (jsonSource != null) { + try { + return objectMapper.readTree(jsonSource); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return null; + } +} \ No newline at end of file diff --git a/adi-common/src/main/java/com/moyz/adi/common/base/SearchEngineRespTypeHandler.java b/adi-common/src/main/java/com/moyz/adi/common/base/SearchEngineRespTypeHandler.java new file mode 100644 index 0000000..245b104 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/base/SearchEngineRespTypeHandler.java @@ -0,0 +1,70 @@ +package com.moyz.adi.common.base; + +import com.moyz.adi.common.dto.SearchEngineResp; +import com.moyz.adi.common.util.JsonUtil; +import org.apache.ibatis.type.BaseTypeHandler; +import org.apache.ibatis.type.JdbcType; +import org.apache.ibatis.type.MappedJdbcTypes; +import org.apache.ibatis.type.MappedTypes; +import org.postgresql.util.PGobject; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +@MappedJdbcTypes({JdbcType.JAVA_OBJECT}) +@MappedTypes({SearchEngineResp.class}) +public class SearchEngineRespTypeHandler extends BaseTypeHandler { + + @Override + public void setNonNullParameter(PreparedStatement ps, int i, SearchEngineResp parameter, JdbcType jdbcType) { + PGobject jsonObject = new PGobject(); + jsonObject.setType("jsonb"); + try { + jsonObject.setValue(JsonUtil.toJson(parameter)); + ps.setObject(i, jsonObject); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public SearchEngineResp getNullableResult(ResultSet rs, String columnName) throws SQLException { + String jsonSource = rs.getString(columnName); + if (jsonSource != null) { + try { + return JsonUtil.fromJson(jsonSource, SearchEngineResp.class); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return null; + } + + @Override + public SearchEngineResp getNullableResult(ResultSet rs, int columnIndex) throws SQLException { + String jsonSource = rs.getString(columnIndex); + if (jsonSource != null) { + try { + return JsonUtil.fromJson(jsonSource, SearchEngineResp.class); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return null; + } + + @Override + public SearchEngineResp getNullableResult(CallableStatement cs, int columnIndex) throws SQLException { + String jsonSource = cs.getString(columnIndex); + if (jsonSource != null) { + try { + return JsonUtil.fromJson(jsonSource, SearchEngineResp.class); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return null; + } +} \ No newline at end of file diff --git a/adi-common/src/main/java/com/moyz/adi/common/config/BeanConfig.java b/adi-common/src/main/java/com/moyz/adi/common/config/BeanConfig.java index 6f2346c..6a5b8a1 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/config/BeanConfig.java +++ b/adi-common/src/main/java/com/moyz/adi/common/config/BeanConfig.java @@ -10,11 +10,14 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; -import com.google.common.collect.Lists; +import com.moyz.adi.common.base.SearchEngineRespTypeHandler; +import com.moyz.adi.common.dto.SearchEngineResp; +import com.moyz.adi.common.service.RAGService; import com.moyz.adi.common.util.LocalDateTimeUtil; import com.pgvector.PGvector; import lombok.extern.slf4j.Slf4j; import org.apache.ibatis.session.SqlSessionFactory; +import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Primary; @@ -32,6 +35,15 @@ import javax.sql.DataSource; @Configuration public class BeanConfig { + @Value("${spring.datasource.url}") + private String dataBaseUrl; + + @Value("${spring.datasource.username}") + private String dataBaseUserName; + + @Value("${spring.datasource.password}") + private String dataBasePassword; + @Bean public RestTemplate restTemplate() { log.info("Configuration:create restTemplate"); @@ -42,7 +54,7 @@ public class BeanConfig { requestFactory.setReadTimeout(60000); RestTemplate restTemplate = new RestTemplate(); // 注册LOG拦截器 - restTemplate.setInterceptors(Lists.newArrayList(new LogClientHttpRequestInterceptor())); +// restTemplate.setInterceptors(Lists.newArrayList(new LogClientHttpRequestInterceptor())); restTemplate.setRequestFactory(new BufferingClientHttpRequestFactory(requestFactory)); return restTemplate; @@ -95,12 +107,28 @@ public class BeanConfig { bean.setMapperLocations( new PathMatchingResourcePatternResolver().getResources("classpath*:/mapper/*.xml")); MybatisConfiguration configuration = bean.getConfiguration(); - if(null == configuration){ + if (null == configuration) { configuration = new MybatisConfiguration(); bean.setConfiguration(configuration); } bean.getConfiguration().getTypeHandlerRegistry().register(PGvector.class, PostgresVectorTypeHandler.class); + bean.getConfiguration().getTypeHandlerRegistry().register(SearchEngineResp.class, SearchEngineRespTypeHandler.class); return bean.getObject(); } + @Bean(name = "kbRagService") + @Primary + public RAGService initKnowledgeBaseRAGService() { + RAGService ragService = new RAGService("adi_knowledge_base_embedding", dataBaseUrl, dataBaseUserName, dataBasePassword); + ragService.init(); + return ragService; + } + + @Bean(name = "searchRagService") + public RAGService initSearchRAGService() { + RAGService ragService = new RAGService("adi_ai_search_embedding", dataBaseUrl, dataBaseUserName, dataBasePassword); + ragService.init(); + return ragService; + } + } diff --git a/adi-common/src/main/java/com/moyz/adi/common/cosntant/AdiConstant.java b/adi-common/src/main/java/com/moyz/adi/common/cosntant/AdiConstant.java index f646f2c..79d2b71 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/cosntant/AdiConstant.java +++ b/adi-common/src/main/java/com/moyz/adi/common/cosntant/AdiConstant.java @@ -1,10 +1,12 @@ package com.moyz.adi.common.cosntant; +import dev.langchain4j.model.input.PromptTemplate; + import java.util.List; public class AdiConstant { - public static final int DEFAULT_PAGE_SIZE = 1; + public static final int DEFAULT_PAGE_SIZE = 10; /** * 验证码id过期时间:1小时 @@ -53,6 +55,13 @@ public class AdiConstant { public static final List OPENAI_CREATE_IMAGE_SIZES = List.of("256x256", "512x512", "1024x1024"); + public static final PromptTemplate PROMPT_TEMPLATE = PromptTemplate.from(""" + 根据以下已知信息: + {{information}} + 尽可能准确地回答用户的问题,以下是用户的问题: + {{question}} + 注意,回答的内容不能让用户感知到已知信息的存在 + """); public static class GenerateImage { public static final int INTERACTING_METHOD_GENERATE_IMAGE = 1; @@ -64,11 +73,21 @@ public class AdiConstant { public static final int STATUS_SUCCESS = 3; } + public static class EmbeddingMetadataKey { + public static final String KB_UUID = "kb_uuid"; + public static final String KB_ITEM_UUID = "kb_item_uuid"; + public static final String ENGINE_NAME = "engine_name"; + public static final String SEARCH_UUID = "search_uuid"; + } + public static class SysConfigKey { public static final String OPENAI_SETTING = "openai_setting"; public static final String DASHSCOPE_SETTING = "dashscope_setting"; public static final String QIANFAN_SETTING = "qianfan_setting"; public static final String OLLAMA_SETTING = "ollama_setting"; + public static final String GOOGLE_SETTING = "google_setting"; + public static final String BING_SETTING = "bing_setting"; + public static final String BAIDU_SETTING = "baidu_setting"; public static final String REQUEST_TEXT_RATE_LIMIT = "request_text_rate_limit"; public static final String REQUEST_IMAGE_RATE_LIMIT = "request_image_rate_limit"; public static final String CONVERSATION_MAX_NUM = "conversation_max_num"; @@ -78,8 +97,25 @@ public class AdiConstant { public static final String QUOTA_BY_REQUEST_MONTHLY = "quota_by_request_monthly"; public static final String QUOTA_BY_IMAGE_DAILY = "quota_by_image_daily"; public static final String QUOTA_BY_IMAGE_MONTHLY = "quota_by_image_monthly"; - + public static final String QUOTA_BY_QA_ASK_DAILY = "quota_by_qa_ask_daily"; } public static final String[] POI_DOC_TYPES = {"doc", "docx", "ppt", "pptx", "xls", "xlsx"}; + + public static class SearchEngineName { + public static final String GOOGLE = "google"; + public static final String BING = "bing"; + public static final String BAIDU = "baidu"; + } + + public static class SSEEventName { + public static final String START = "[START]"; + public static final String DONE = "[DONE]"; + public static final String ERROR = "[ERROR]"; + + public static final String AI_SEARCH_SOURCE_LINKS = "[SOURCE_LINKS]"; + } + + public static final int RAG_TYPE_KB = 1; + public static final int RAG_TYPE_SEARCH = 2; } diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/AiSearchRecordResp.java b/adi-common/src/main/java/com/moyz/adi/common/dto/AiSearchRecordResp.java new file mode 100644 index 0000000..c8b492c --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/AiSearchRecordResp.java @@ -0,0 +1,33 @@ +package com.moyz.adi.common.dto; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.time.LocalDateTime; + +@NoArgsConstructor +@AllArgsConstructor +@Builder +@Data +public class AiSearchRecordResp { + + private String uuid; + + private String question; + + private SearchEngineResp searchEngineResp; + + private String prompt; + + private Integer promptTokens; + + private String answer; + + private Integer answerTokens; + + private String userUuid; + + private LocalDateTime createTime; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/AiSearchReq.java b/adi-common/src/main/java/com/moyz/adi/common/dto/AiSearchReq.java new file mode 100644 index 0000000..bc9d8fb --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/AiSearchReq.java @@ -0,0 +1,19 @@ +package com.moyz.adi.common.dto; + +import jakarta.validation.constraints.NotBlank; +import lombok.Data; +import org.springframework.validation.annotation.Validated; + +@Validated +@Data +public class AiSearchReq { + + @NotBlank + private String searchText; + + private String engineName; + + private String modelName; + + private boolean briefSearch; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/AiSearchResp.java b/adi-common/src/main/java/com/moyz/adi/common/dto/AiSearchResp.java new file mode 100644 index 0000000..6034909 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/AiSearchResp.java @@ -0,0 +1,13 @@ +package com.moyz.adi.common.dto; + +import lombok.Data; + +import java.util.List; + +@Data +public class AiSearchResp { + + private Long minId; + + private List records; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/GoogleSearchError.java b/adi-common/src/main/java/com/moyz/adi/common/dto/GoogleSearchError.java new file mode 100644 index 0000000..54f0070 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/GoogleSearchError.java @@ -0,0 +1,9 @@ +package com.moyz.adi.common.dto; + +import lombok.Data; + +@Data +public class GoogleSearchError { + private Integer code; + private String message; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/GoogleSearchResp.java b/adi-common/src/main/java/com/moyz/adi/common/dto/GoogleSearchResp.java new file mode 100644 index 0000000..ad0d443 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/GoogleSearchResp.java @@ -0,0 +1,47 @@ +package com.moyz.adi.common.dto; + +import lombok.Data; + +import java.util.List; + +@Data +public class GoogleSearchResp { + private String kind; + private Queries queries; + private SearchInformation searchInformation; + private List items; + private GoogleSearchError error; + + @Data + public static class Queries { + private Request[] request; + } + + @Data + public static class Request { + private String title; + private String totalResults; + private String searchTerms; + private Integer count; + private Integer startIndex; + private String inputEncoding; + private String outputEncoding; + } + + @Data + public static class SearchInformation { + private double searchTime; + private String formattedSearchTime; + private String totalResults; + private String formattedTotalResults; + } + + @Data + public static class Item { + private String kind; + private String title; + private String htmlTitle; + private String link; + private String snippet; + } +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/SearchEngineResp.java b/adi-common/src/main/java/com/moyz/adi/common/dto/SearchEngineResp.java new file mode 100644 index 0000000..71b495d --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/SearchEngineResp.java @@ -0,0 +1,12 @@ +package com.moyz.adi.common.dto; + +import lombok.Data; +import lombok.experimental.Accessors; + +import java.util.List; + +@Data +@Accessors(chain = true) +public class SearchEngineResp { + private List items; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/SearchResult.java b/adi-common/src/main/java/com/moyz/adi/common/dto/SearchResult.java new file mode 100644 index 0000000..1c5b4dd --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/SearchResult.java @@ -0,0 +1,11 @@ +package com.moyz.adi.common.dto; + +import lombok.Data; + +import java.util.List; + +@Data +public class SearchResult { + private String errorMessage; + private List items; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/dto/SearchResultItem.java b/adi-common/src/main/java/com/moyz/adi/common/dto/SearchResultItem.java new file mode 100644 index 0000000..c7ff032 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/dto/SearchResultItem.java @@ -0,0 +1,11 @@ +package com.moyz.adi.common.dto; + +import lombok.Data; + +@Data +public class SearchResultItem { + private String title; + private String link; + private String snippet; + private String content; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/entity/AiSearchRecord.java b/adi-common/src/main/java/com/moyz/adi/common/entity/AiSearchRecord.java new file mode 100644 index 0000000..87419b1 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/entity/AiSearchRecord.java @@ -0,0 +1,50 @@ +package com.moyz.adi.common.entity; + +import com.baomidou.mybatisplus.annotation.TableField; +import com.baomidou.mybatisplus.annotation.TableName; +import com.moyz.adi.common.base.SearchEngineRespTypeHandler; +import com.moyz.adi.common.dto.SearchEngineResp; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; +import org.apache.ibatis.type.JdbcType; + +@Data +@TableName("adi_ai_search_record") +@Schema(title = "AiSearchRecord对象", description = "AI搜索记录表") +public class AiSearchRecord extends BaseEntity { + + @TableField("uuid") + private String uuid; + + @Schema(title = "问题") + @TableField("question") + private String question; + + @Schema(title = "Search engine's response content") + @TableField(value = "search_engine_response", jdbcType = JdbcType.JAVA_OBJECT, typeHandler = SearchEngineRespTypeHandler.class) + private SearchEngineResp searchEngineResp; + + @Schema(title = "最终提供给LLM的提示词") + @TableField("prompt") + private String prompt; + + @Schema(title = "提供给LLM的提示词所消耗的token数量") + @TableField("prompt_tokens") + private Integer promptTokens; + + @Schema(title = "答案") + @TableField("answer") + private String answer; + + @Schema(title = "答案消耗的token") + @TableField("answer_tokens") + private Integer answerTokens; + + @Schema(title = "提问用户uuid") + @TableField("user_uuid") + private String userUuid; + + @Schema(title = "提问用户id") + @TableField("user_id") + private Long userId; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/enums/ErrorEnum.java b/adi-common/src/main/java/com/moyz/adi/common/enums/ErrorEnum.java index af4e93b..061a9bc 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/enums/ErrorEnum.java +++ b/adi-common/src/main/java/com/moyz/adi/common/enums/ErrorEnum.java @@ -34,7 +34,7 @@ public enum ErrorEnum { B_MESSAGE_NOT_FOUND("B0008", "消息不存在"), B_LLM_SERVICE_DISABLED("B0009", "LLM服务不可用"), B_KNOWLEDGE_BASE_IS_EMPTY("B0010", "知识库内容为空"), - B_KNOWLEDGE_BASE_NO_ANSWER("B0011", "[无答案]") + B_NO_ANSWER("B0011", "[无答案]") ; private String code; diff --git a/adi-common/src/main/java/com/moyz/adi/common/helper/QuotaHelper.java b/adi-common/src/main/java/com/moyz/adi/common/helper/QuotaHelper.java index ed82c6f..1ba57a9 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/helper/QuotaHelper.java +++ b/adi-common/src/main/java/com/moyz/adi/common/helper/QuotaHelper.java @@ -17,10 +17,10 @@ public class QuotaHelper { private UserDayCostService userDayCostService; public ErrorEnum checkTextQuota(User user) { - if (StringUtils.isNotBlank(user.getSecretKey())) { - log.info("Custom secret key,dont need to check text request quota,userId:{}", user.getId()); - return null; - } +// if (StringUtils.isNotBlank(user.getSecretKey())) { +// log.info("Custom secret key,dont need to check text request quota,userId:{}", user.getId()); +// return null; +// } int userQuotaByTokenDay = user.getQuotaByTokenDaily(); int userQuotaByTokenMonth = user.getQuotaByTokenMonthly(); int userQuotaByRequestDay = user.getQuotaByRequestDaily(); diff --git a/adi-common/src/main/java/com/moyz/adi/common/helper/SSEEmitterHelper.java b/adi-common/src/main/java/com/moyz/adi/common/helper/SSEEmitterHelper.java index c3e8aa7..ee6e022 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/helper/SSEEmitterHelper.java +++ b/adi-common/src/main/java/com/moyz/adi/common/helper/SSEEmitterHelper.java @@ -1,5 +1,6 @@ package com.moyz.adi.common.helper; +import com.moyz.adi.common.cosntant.AdiConstant; import com.moyz.adi.common.cosntant.RedisKeyConstant; import com.moyz.adi.common.entity.User; import com.moyz.adi.common.interfaces.TriConsumer; @@ -28,35 +29,45 @@ public class SSEEmitterHelper { @Resource private RateLimitHelper rateLimitHelper; - public void process(User user, SseAskParams sseAskParams, TriConsumer consumer) { - SseEmitter sseEmitter = sseAskParams.getSseEmitter(); - - //rate limit by system + public boolean checkOrComplete(User user, SseEmitter sseEmitter) { + //Check: rate limit String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId()); if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) { - sendErrorMsg(sseEmitter, "访问太过频繁"); - return; + sendErrorAndComplete(user.getId(), sseEmitter, "访问太过频繁"); + return false; } //Check: If still waiting response String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId()); String askingVal = stringRedisTemplate.opsForValue().get(askingKey); if (StringUtils.isNotBlank(askingVal)) { - sendErrorMsg(sseEmitter, "正在回复中..."); - return; + sendErrorAndComplete(user.getId(), sseEmitter, "正在回复中..."); + return false; } + return true; + } + + public void startSse(User user, SseEmitter sseEmitter) { + + String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId()); stringRedisTemplate.opsForValue().set(askingKey, "1", 15, TimeUnit.SECONDS); + + String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId()); + rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG); try { - sseEmitter.send(SseEmitter.event().name("[START]")); + sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.START)); } catch (IOException e) { log.error("error", e); sseEmitter.completeWithError(e); stringRedisTemplate.delete(askingKey); - return; } + } - rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG); + public void processAndPushToModel(User user, SseAskParams sseAskParams, TriConsumer consumer) { + String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId()); + + SseEmitter sseEmitter = sseAskParams.getSseEmitter(); sseEmitter.onCompletion(() -> { log.info("response complete,uid:{}", user.getId()); }); @@ -65,7 +76,7 @@ public class SSEEmitterHelper { throwable -> { try { log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable); - sseEmitter.send(SseEmitter.event().name("[ERROR]").data(throwable.getMessage())); + sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.ERROR).data(throwable.getMessage())); } catch (IOException e) { log.error("error", e); } finally { @@ -84,22 +95,29 @@ public class SSEEmitterHelper { }); } - public void sendAndComplete(SseEmitter sseEmitter, String msg){ + public void sendAndComplete(long userId, SseEmitter sseEmitter, String msg) { try { - sseEmitter.send(SseEmitter.event().name("[START]")); - sseEmitter.send(SseEmitter.event().name("[DONE]").data(msg)); + sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.START)); + sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.DONE).data(msg)); } catch (IOException e) { throw new RuntimeException(e); } sseEmitter.complete(); + delSseRequesting(userId); } - public void sendErrorMsg(SseEmitter sseEmitter, String errorMsg) { + public void sendErrorAndComplete(long userId, SseEmitter sseEmitter, String errorMsg) { try { - sseEmitter.send(SseEmitter.event().name("[ERROR]").data(errorMsg)); + sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.ERROR).data(errorMsg)); } catch (IOException e) { throw new RuntimeException(e); } sseEmitter.complete(); + delSseRequesting(userId); + } + + private void delSseRequesting(long userId) { + String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, userId); + stringRedisTemplate.delete(askingKey); } } diff --git a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractImageModelService.java b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractImageModelService.java index 896ef84..3465a18 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractImageModelService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractImageModelService.java @@ -39,13 +39,17 @@ public abstract class AbstractImageModelService { protected ImageModel imageModel; - public AbstractImageModelService(String modelName, String settingName, Class clazz, Proxy proxy){ + public AbstractImageModelService(String modelName, String settingName, Class clazz) { this.modelName = modelName; - this.proxy = proxy; String st = LocalCache.CONFIGS.get(settingName); setting = JsonUtil.fromJson(st, clazz); } + public AbstractImageModelService setProxy(Proxy proxy) { + this.proxy = proxy; + return this; + } + public ImageModel getImageModel(User user, String size) { if (null != imageModel) { return imageModel; @@ -56,6 +60,7 @@ public abstract class AbstractImageModelService { /** * 检测该service是否可用(不可用的情况通过是没有配置key) + * * @return */ public abstract boolean isEnabled(); diff --git a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java index dc68260..5611841 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractLLMService.java @@ -36,13 +36,17 @@ public abstract class AbstractLLMService { protected StreamingChatLanguageModel streamingChatLanguageModel; protected ChatLanguageModel chatLanguageModel; - public AbstractLLMService(String modelName, String settingName, Class clazz, Proxy proxy) { + public AbstractLLMService(String modelName, String settingName, Class clazz) { this.modelName = modelName; - this.proxy = proxy; String st = LocalCache.CONFIGS.get(settingName); setting = JsonUtil.fromJson(st, clazz); } + public AbstractLLMService setProxy(Proxy proxy) { + this.proxy = proxy; + return this; + } + /** * 检测该service是否可用(不可用的情况通常是没有配置key) * @@ -73,7 +77,7 @@ public abstract class AbstractLLMService { protected abstract String parseError(Object error); public Response chat(ChatMessage chatMessage) { - if(!isEnabled()){ + if (!isEnabled()) { log.error("llm service is disabled"); throw new BaseException(B_LLM_SERVICE_DISABLED); } @@ -81,7 +85,7 @@ public abstract class AbstractLLMService { } public void sseChat(SseAskParams params, TriConsumer consumer) { - if(!isEnabled()){ + if (!isEnabled()) { log.error("llm service is disabled"); throw new BaseException(B_LLM_SERVICE_DISABLED); } @@ -131,7 +135,7 @@ public abstract class AbstractLLMService { log.error("stream error", error); try { String errorMsg = parseError(error); - if(StringUtils.isBlank(errorMsg)){ + if (StringUtils.isBlank(errorMsg)) { errorMsg = error.getMessage(); } params.getSseEmitter().send(SseEmitter.event().name("[ERROR]").data(errorMsg)); diff --git a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractSearchEngine.java b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractSearchEngine.java new file mode 100644 index 0000000..b250140 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractSearchEngine.java @@ -0,0 +1,45 @@ +package com.moyz.adi.common.interfaces; + +import com.moyz.adi.common.dto.SearchResult; +import com.moyz.adi.common.util.JsonUtil; +import com.moyz.adi.common.util.LocalCache; +import com.moyz.adi.common.util.SpringUtil; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.web.client.RestTemplate; + +import java.net.Proxy; + +public abstract class AbstractSearchEngine { + + protected String engineName; + + protected Proxy proxy; + + public AbstractSearchEngine(String engineName, String settingName, Class clazz) { + this.engineName = engineName; + String st = LocalCache.CONFIGS.get(settingName); + setting = JsonUtil.fromJson(st, clazz); + } + + + protected T setting; + + public abstract boolean isEnabled(); + + public abstract SearchResult search(String searchTxt); + + public AbstractSearchEngine setProxy(Proxy proxy) { + this.proxy = proxy; + return this; + } + + protected RestTemplate getRestTemplate() { + RestTemplate restTemplate = SpringUtil.getBean(RestTemplate.class); + if (null != proxy) { + SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); + requestFactory.setProxy(proxy); + restTemplate.setRequestFactory(requestFactory); + } + return restTemplate; + } +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractSearchService.java b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractSearchService.java new file mode 100644 index 0000000..4c67691 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/interfaces/AbstractSearchService.java @@ -0,0 +1,11 @@ +package com.moyz.adi.common.interfaces; + +public abstract class AbstractSearchService { + + public abstract boolean isEnabled(); + + public abstract void briefSearch(String question); + + public abstract void detailSearch(String question); + +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/mapper/AiSearchRecordMapper.java b/adi-common/src/main/java/com/moyz/adi/common/mapper/AiSearchRecordMapper.java new file mode 100644 index 0000000..b1dba1d --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/mapper/AiSearchRecordMapper.java @@ -0,0 +1,9 @@ +package com.moyz.adi.common.mapper; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.moyz.adi.common.entity.AiSearchRecord; +import org.apache.ibatis.annotations.Mapper; + +@Mapper +public interface AiSearchRecordMapper extends BaseMapper { +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/searchengine/GoogleSearchEngine.java b/adi-common/src/main/java/com/moyz/adi/common/searchengine/GoogleSearchEngine.java new file mode 100644 index 0000000..ddfbd21 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/searchengine/GoogleSearchEngine.java @@ -0,0 +1,56 @@ +package com.moyz.adi.common.searchengine; + +import com.moyz.adi.common.cosntant.AdiConstant; +import com.moyz.adi.common.dto.GoogleSearchResp; +import com.moyz.adi.common.dto.SearchResult; +import com.moyz.adi.common.dto.SearchResultItem; +import com.moyz.adi.common.interfaces.AbstractSearchEngine; +import com.moyz.adi.common.util.MPPageUtil; +import com.moyz.adi.common.vo.GoogleSetting; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; + +import java.text.MessageFormat; +import java.util.ArrayList; +import java.util.List; + +import static com.moyz.adi.common.cosntant.AdiConstant.SysConfigKey.GOOGLE_SETTING; + +@Slf4j +public class GoogleSearchEngine extends AbstractSearchEngine { + + public GoogleSearchEngine() { + super(AdiConstant.SearchEngineName.GOOGLE, GOOGLE_SETTING, GoogleSetting.class); + } + + @Override + public boolean isEnabled() { + return StringUtils.isNoneBlank(setting.getKey(), setting.getCx()); + } + + @Override + public SearchResult search(String searchTxt) { + SearchResult result = new SearchResult(); + List items = new ArrayList<>(); + try { + ResponseEntity resp = getRestTemplate().getForEntity(MessageFormat.format("{0}?key={1}&cx={2}&q={3}", setting.getUrl(), setting.getKey(), setting.getCx(), searchTxt), GoogleSearchResp.class); + if (null != resp && HttpStatus.OK.isSameCodeAs(resp.getStatusCode())) { + GoogleSearchResp googleSearchResp = resp.getBody(); + if (null != googleSearchResp.getError()) { + log.error("google search error,code:{},message:{}", googleSearchResp.getError().getCode(), googleSearchResp.getError().getMessage()); + result.setErrorMessage(googleSearchResp.getError().getMessage()); + } else { + log.info("google response:{}", resp); + items = MPPageUtil.convertTo(googleSearchResp.getItems(), SearchResultItem.class); + } + } + } catch (Exception e) { + log.error("google search error", e); + } + result.setItems(items); + return result; + } + +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/searchengine/SearchEngineContext.java b/adi-common/src/main/java/com/moyz/adi/common/searchengine/SearchEngineContext.java new file mode 100644 index 0000000..914a14c --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/searchengine/SearchEngineContext.java @@ -0,0 +1,40 @@ +package com.moyz.adi.common.searchengine; + +import com.moyz.adi.common.cosntant.AdiConstant; +import com.moyz.adi.common.interfaces.AbstractSearchEngine; +import com.moyz.adi.common.vo.SearchEngineInfo; +import lombok.extern.slf4j.Slf4j; + +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Search engine context. strategy design model + */ +@Slf4j +public class SearchEngineContext { + public static final Map NAME_TO_ENGINE = new LinkedHashMap<>(); + + private AbstractSearchEngine iSearchEngine; + + public SearchEngineContext(String searchEngineName) { + if (null == NAME_TO_ENGINE.get(searchEngineName)) { + log.warn("︿︿︿ Can not find {}, use the default engine GOOGLE ︿︿︿", searchEngineName); + iSearchEngine = NAME_TO_ENGINE.get(AdiConstant.SearchEngineName.GOOGLE).getEngine(); + } else { + iSearchEngine = NAME_TO_ENGINE.get(searchEngineName).getEngine(); + } + } + + public static void addEngine(String engineName, AbstractSearchEngine searchEngine) { + SearchEngineInfo info = new SearchEngineInfo(); + info.setName(engineName); + info.setEnable(searchEngine.isEnabled()); + info.setEngine(searchEngine); + NAME_TO_ENGINE.put(engineName, info); + } + + public AbstractSearchEngine getEngine() { + return iSearchEngine; + } +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/AiSearchRecordService.java b/adi-common/src/main/java/com/moyz/adi/common/service/AiSearchRecordService.java new file mode 100644 index 0000000..a83caf3 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/service/AiSearchRecordService.java @@ -0,0 +1,82 @@ +package com.moyz.adi.common.service; + +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers; +import com.moyz.adi.common.base.ThreadContext; +import com.moyz.adi.common.dto.AiSearchRecordResp; +import com.moyz.adi.common.dto.AiSearchResp; +import com.moyz.adi.common.dto.SearchEngineResp; +import com.moyz.adi.common.entity.AiSearchRecord; +import com.moyz.adi.common.exception.BaseException; +import com.moyz.adi.common.mapper.AiSearchRecordMapper; +import com.moyz.adi.common.util.BizPager; +import com.moyz.adi.common.util.MPPageUtil; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.BeanUtils; +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.List; + +import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND; + +/** + * Ai search + */ +@Slf4j +@Service +public class AiSearchRecordService extends ServiceImpl { + + /** + * List search records + * + * @param maxId Anchor id + * @param keyword user's question + * @return + */ + public AiSearchResp listByMaxId(Long maxId, String keyword) { + LambdaQueryWrapper wrapper = new LambdaQueryWrapper<>(); + wrapper.eq(AiSearchRecord::getUserId, ThreadContext.getCurrentUserId()); + wrapper.eq(AiSearchRecord::getIsDeleted, false); + if (StringUtils.isNotBlank(keyword)) { + wrapper.like(AiSearchRecord::getQuestion, keyword); + } + AiSearchResp result = new AiSearchResp(); + BizPager.listByMaxId(maxId, wrapper, this, AiSearchRecord::getId, (recordList, minId) -> { + List list = MPPageUtil.convertTo(recordList, AiSearchRecordResp.class); + list.forEach(item -> { + if(null == item.getSearchEngineResp()){ + SearchEngineResp searchEngineResp = new SearchEngineResp(); + searchEngineResp.setItems(new ArrayList<>()); + item.setSearchEngineResp(searchEngineResp); + } + }); + result.setRecords(list); + result.setMinId(minId); + }); + return result; + } + + public boolean softDelete(String uuid) { + if (ThreadContext.getCurrentUser().getIsAdmin()) { + return ChainWrappers.lambdaUpdateChain(baseMapper) + .eq(AiSearchRecord::getUuid, uuid) + .set(AiSearchRecord::getIsDeleted, true) + .update(); + } + AiSearchRecord exist = ChainWrappers.lambdaQueryChain(baseMapper) + .eq(AiSearchRecord::getUuid, uuid) + .eq(AiSearchRecord::getUserId, ThreadContext.getCurrentUserId()) + .one(); + if (null == exist) { + throw new BaseException(A_DATA_NOT_FOUND); + } + return ChainWrappers.lambdaUpdateChain(baseMapper) + .eq(AiSearchRecord::getId, exist.getId()) + .set(AiSearchRecord::getIsDeleted, true) + .update(); + } + +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java b/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java index c2eb480..28c1278 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/ConversationMessageService.java @@ -8,7 +8,6 @@ import com.moyz.adi.common.dto.AskReq; import com.moyz.adi.common.entity.Conversation; import com.moyz.adi.common.entity.ConversationMessage; import com.moyz.adi.common.entity.User; -import com.moyz.adi.common.entity.UserDayCost; import com.moyz.adi.common.enums.ChatMessageRoleEnum; import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.exception.BaseException; @@ -16,8 +15,6 @@ import com.moyz.adi.common.helper.QuotaHelper; import com.moyz.adi.common.helper.SSEEmitterHelper; import com.moyz.adi.common.mapper.ConversationMessageMapper; import com.moyz.adi.common.util.LocalCache; -import com.moyz.adi.common.util.LocalDateTimeUtil; -import com.moyz.adi.common.util.UserUtil; import com.moyz.adi.common.vo.AnswerMeta; import com.moyz.adi.common.vo.PromptMeta; import com.moyz.adi.common.vo.SseAskParams; @@ -66,11 +63,16 @@ public class ConversationMessageService extends ServiceImpl= convsMax) { - sseEmitterHelper.sendErrorMsg(sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax); + sseEmitterHelper.sendErrorAndComplete(user.getId(), sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax); return false; } //check 3: current user's quota ErrorEnum errorMsg = quotaHelper.checkTextQuota(user); if (null != errorMsg) { - sseEmitterHelper.sendErrorMsg(sseEmitter, errorMsg.getInfo()); + sseEmitterHelper.sendErrorAndComplete(user.getId(), sseEmitter, errorMsg.getInfo()); return false; } } catch (Exception e) { @@ -112,7 +114,7 @@ public class ConversationMessageService extends ServiceImpl { + sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, questionMeta, answerMeta) -> { _this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta); }); } diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/DashScopeLLMService.java b/adi-common/src/main/java/com/moyz/adi/common/service/DashScopeLLMService.java index 4f33770..1702e5d 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/DashScopeLLMService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/DashScopeLLMService.java @@ -22,7 +22,7 @@ import static com.moyz.adi.common.enums.ErrorEnum.B_LLM_SECRET_KEY_NOT_SET; public class DashScopeLLMService extends AbstractLLMService { public DashScopeLLMService(String modelName) { - super(modelName, AdiConstant.SysConfigKey.DASHSCOPE_SETTING, DashScopeSetting.class, null); + super(modelName, AdiConstant.SysConfigKey.DASHSCOPE_SETTING, DashScopeSetting.class); } @Override diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/Initializer.java b/adi-common/src/main/java/com/moyz/adi/common/service/Initializer.java index 9226ad0..8fad37a 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/Initializer.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/Initializer.java @@ -3,6 +3,8 @@ package com.moyz.adi.common.service; import com.moyz.adi.common.cosntant.AdiConstant; import com.moyz.adi.common.helper.ImageModelContext; import com.moyz.adi.common.helper.LLMContext; +import com.moyz.adi.common.searchengine.GoogleSearchEngine; +import com.moyz.adi.common.searchengine.SearchEngineContext; import dev.langchain4j.model.openai.OpenAiModelName; import jakarta.annotation.PostConstruct; import jakarta.annotation.Resource; @@ -43,16 +45,16 @@ public class Initializer { //openai String[] openaiModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OPENAI_SETTING); - if(openaiModels.length == 0){ + if (openaiModels.length == 0) { log.warn("openai service is disabled"); } for (String model : openaiModels) { - LLMContext.addLLMService(model, new OpenAiLLMService(model, proxy)); + LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy)); } //dashscope String[] dashscopeModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.DASHSCOPE_SETTING); - if(dashscopeModels.length == 0){ + if (dashscopeModels.length == 0) { log.warn("dashscope service is disabled"); } for (String model : dashscopeModels) { @@ -61,7 +63,7 @@ public class Initializer { //qianfan String[] qianfanModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.QIANFAN_SETTING); - if(qianfanModels.length == 0){ + if (qianfanModels.length == 0) { log.warn("qianfan service is disabled"); } for (String model : qianfanModels) { @@ -70,14 +72,18 @@ public class Initializer { //ollama String[] ollamaModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OLLAMA_SETTING); - if(ollamaModels.length == 0){ + if (ollamaModels.length == 0) { log.warn("ollama service is disabled"); } for (String model : ollamaModels) { LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model)); } - ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2, proxy)); + ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2).setProxy(proxy)); + + + //search engine + SearchEngineContext.addEngine(AdiConstant.SearchEngineName.GOOGLE, new GoogleSearchEngine().setProxy(proxy)); ragService.init(); diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseItemService.java b/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseItemService.java index 85db6ba..d8ad0e2 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseItemService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/KnowledgeBaseItemService.java @@ -5,6 +5,7 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers; import com.moyz.adi.common.base.ThreadContext; +import com.moyz.adi.common.cosntant.AdiConstant; import com.moyz.adi.common.dto.KbItemEditReq; import com.moyz.adi.common.entity.KnowledgeBase; import com.moyz.adi.common.entity.KnowledgeBaseItem; @@ -24,6 +25,7 @@ import org.springframework.transaction.annotation.Transactional; import java.util.Optional; import java.util.UUID; +import static com.moyz.adi.common.cosntant.AdiConstant.RAG_TYPE_KB; import static com.moyz.adi.common.enums.ErrorEnum.*; @Slf4j @@ -112,8 +114,8 @@ public class KnowledgeBaseItemService extends ServiceImpl wrapper = new LambdaQueryWrapper(); wrapper.eq(KnowledgeBase::getIsPublic, true); wrapper.eq(KnowledgeBase::getIsDeleted, false); - if(StringUtils.isNotBlank(keyword)){ + if (StringUtils.isNotBlank(keyword)) { wrapper.like(KnowledgeBase::getTitle, keyword); } wrapper.orderByDesc(KnowledgeBase::getStarCount, KnowledgeBase::getUpdateTime); @@ -223,7 +225,8 @@ public class KnowledgeBaseService extends ServiceImpl> responsePair = ragService.retrieveAndAsk(kbUuid, question, modelName); + Map metadataCond = ImmutableMap.of(AdiConstant.EmbeddingMetadataKey.KB_UUID, kbUuid); + Pair> responsePair = ragService.retrieveAndAsk(metadataCond, question, modelName); Response ar = responsePair.getRight(); int inputTokenCount = ar.tokenUsage().inputTokenCount(); @@ -235,7 +238,12 @@ public class KnowledgeBaseService extends ServiceImpl= Integer.parseInt(askQuota)) { throw new BaseException(A_QA_ASK_LIMIT); } @@ -265,10 +273,10 @@ public class KnowledgeBaseService extends ServiceImpl metadataCond = ImmutableMap.of(AdiConstant.EmbeddingMetadataKey.KB_UUID, kbUuid); + Prompt prompt = ragService.retrieveAndCreatePrompt(metadataCond, req.getQuestion()); + if (null == prompt) { + sseEmitterHelper.sendAndComplete(user.getId(), sseEmitter, B_NO_ANSWER.getInfo()); return; } String promptText = prompt.text(); @@ -277,7 +285,7 @@ public class KnowledgeBaseService extends ServiceImpl { + sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, promptMeta, answerMeta) -> { knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), promptText, promptMeta.getTokens(), response, answerMeta.getTokens()); userDayCostService.appendCostToUser(user, promptMeta.getTokens() + answerMeta.getTokens()); }); diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/OllamaLLMService.java b/adi-common/src/main/java/com/moyz/adi/common/service/OllamaLLMService.java index ade98cc..73ec58d 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/OllamaLLMService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/OllamaLLMService.java @@ -7,14 +7,13 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.ollama.OllamaChatModel; import dev.langchain4j.model.ollama.OllamaStreamingChatModel; import org.apache.commons.lang3.StringUtils; -import org.springframework.stereotype.Service; import static com.moyz.adi.common.cosntant.AdiConstant.SysConfigKey.OLLAMA_SETTING; public class OllamaLLMService extends AbstractLLMService { public OllamaLLMService(String modelName) { - super(modelName, OLLAMA_SETTING, OllamaSetting.class, null); + super(modelName, OLLAMA_SETTING, OllamaSetting.class); } @Override diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/OpenAiImageModelService.java b/adi-common/src/main/java/com/moyz/adi/common/service/OpenAiImageModelService.java index b777b66..8f97e53 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/OpenAiImageModelService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/OpenAiImageModelService.java @@ -45,8 +45,8 @@ public class OpenAiImageModelService extends AbstractImageModelService { - public OpenAiLLMService(String modelName, Proxy proxy) { - super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class, proxy); + public OpenAiLLMService(String modelName) { + super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class); } @Override diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/QianFanLLMService.java b/adi-common/src/main/java/com/moyz/adi/common/service/QianFanLLMService.java index 16f94b1..a8c0f0d 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/QianFanLLMService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/QianFanLLMService.java @@ -19,7 +19,7 @@ import org.apache.commons.lang3.StringUtils; public class QianFanLLMService extends AbstractLLMService { public QianFanLLMService(String modelName) { - super(modelName, AdiConstant.SysConfigKey.QIANFAN_SETTING, QianFanSetting.class, null); + super(modelName, AdiConstant.SysConfigKey.QIANFAN_SETTING, QianFanSetting.class); } @Override diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java b/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java index af31f51..e65eb80 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java +++ b/adi-common/src/main/java/com/moyz/adi/common/service/RAGService.java @@ -1,10 +1,7 @@ package com.moyz.adi.common.service; import com.moyz.adi.common.helper.LLMContext; -import com.moyz.adi.common.interfaces.TriConsumer; import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore; -import com.moyz.adi.common.vo.AnswerMeta; -import com.moyz.adi.common.vo.PromptMeta; import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.DocumentSplitter; import dev.langchain4j.data.document.splitter.DocumentSplitters; @@ -14,7 +11,6 @@ import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.input.Prompt; -import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.openai.OpenAiTokenizer; import dev.langchain4j.model.output.Response; import dev.langchain4j.store.embedding.EmbeddingMatch; @@ -24,35 +20,37 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.lang3.tuple.Triple; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Service; import java.util.List; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; +import static com.moyz.adi.common.cosntant.AdiConstant.PROMPT_TEMPLATE; import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static java.util.stream.Collectors.joining; @Slf4j -@Service public class RAGService { - @Value("${spring.datasource.url}") private String dataBaseUrl; - @Value("${spring.datasource.username}") private String dataBaseUserName; - @Value("${spring.datasource.password}") private String dataBasePassword; - private static final PromptTemplate promptTemplate = PromptTemplate.from("尽可能准确地回答下面的问题: {{question}}\n\n根据以下知识库的内容:\n{{information}}"); + + private String tableName; private EmbeddingModel embeddingModel; private EmbeddingStore embeddingStore; + public RAGService(String tableName, String dataBaseUrl, String dataBaseUserName, String dataBasePassword) { + this.tableName = tableName; + this.dataBasePassword = dataBasePassword; + this.dataBaseUserName = dataBaseUserName; + this.dataBaseUrl = dataBaseUrl; + } + public void init() { log.info("initEmbeddingModel"); embeddingModel = new AllMiniLmL6V2EmbeddingModel(); @@ -88,7 +86,7 @@ public class RAGService { .dimension(384) .createTable(true) .dropTableFirst(false) - .table("adi_knowledge_base_embedding") + .table(tableName) .build(); return embeddingStore; } @@ -112,7 +110,14 @@ public class RAGService { getEmbeddingStoreIngestor().ingest(document); } - public Prompt retrieveAndCreatePrompt(String kbUuid, String question) { + /** + * Retrieve documents and create prompt + * + * @param metadataCond Query condition + * @param question User's question + * @return Document in the vector db + */ + public Prompt retrieveAndCreatePrompt(Map metadataCond, String question) { // Embed the question Embedding questionEmbedding = embeddingModel.embed(question).content(); @@ -120,7 +125,7 @@ public class RAGService { // You can play with parameters below to find a sweet spot for your specific use case int maxResults = 3; double minScore = 0.6; - List> relevantEmbeddings = ((AdiPgVectorEmbeddingStore) embeddingStore).findRelevantByKbUuid(kbUuid, questionEmbedding, maxResults, minScore); + List> relevantEmbeddings = ((AdiPgVectorEmbeddingStore) embeddingStore).findRelevantByMetadata(metadataCond, questionEmbedding, maxResults, minScore); // Create a prompt for the model that includes question and relevant embeddings String information = relevantEmbeddings.stream() @@ -130,23 +135,28 @@ public class RAGService { if (StringUtils.isBlank(information)) { return null; } - return promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information))); + return PROMPT_TEMPLATE.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information))); } /** * 召回并提问 * - * @param kbUuid 知识库uuid - * @param question 用户的问题 - * @param modelName LLM model name + * @param metadataCond metadata condition + * @param question user's question + * @param modelName LLM model name * @return */ - public Pair> retrieveAndAsk(String kbUuid, String question, String modelName) { - Prompt prompt = retrieveAndCreatePrompt(kbUuid, question); + public Pair> retrieveAndAsk(Map metadataCond, String question, String modelName) { + + Prompt prompt = retrieveAndCreatePrompt(metadataCond, question); if (null == prompt) { return null; } Response response = new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage()); return new ImmutablePair<>(prompt.text(), response); } + + public static final String parsePromptTemplate(String question, String information) { + return PROMPT_TEMPLATE.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information))).text(); + } } diff --git a/adi-common/src/main/java/com/moyz/adi/common/service/SearchService.java b/adi-common/src/main/java/com/moyz/adi/common/service/SearchService.java new file mode 100644 index 0000000..f3be1d1 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/service/SearchService.java @@ -0,0 +1,250 @@ +package com.moyz.adi.common.service; + +import com.google.common.collect.ImmutableMap; +import com.moyz.adi.common.base.ThreadContext; +import com.moyz.adi.common.cosntant.AdiConstant; +import com.moyz.adi.common.dto.SearchEngineResp; +import com.moyz.adi.common.dto.SearchResult; +import com.moyz.adi.common.dto.SearchResultItem; +import com.moyz.adi.common.entity.AiSearchRecord; +import com.moyz.adi.common.entity.User; +import com.moyz.adi.common.helper.SSEEmitterHelper; +import com.moyz.adi.common.searchengine.SearchEngineContext; +import com.moyz.adi.common.vo.SseAskParams; +import dev.langchain4j.data.document.Document; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.model.input.Prompt; +import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.jsoup.Jsoup; +import org.jsoup.safety.Cleaner; +import org.jsoup.safety.Safelist; +import org.springframework.context.annotation.Lazy; +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.io.IOException; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; + +import static com.moyz.adi.common.enums.ErrorEnum.B_NO_ANSWER; + +/** + * RAG search + */ +@Slf4j +@Service +public class SearchService { + + @Lazy + @Resource + private SearchService _this; + + @Resource + private RAGService searchRagService; + + @Resource + private SSEEmitterHelper sseEmitterHelper; + + @Resource + private AiSearchRecordService aiSearchRecordService; + + @Resource + private AsyncTaskExecutor mainExecutor; + + public SseEmitter search(boolean isBriefSearch, String searchText, String engineName, String modelName) { + User user = ThreadContext.getCurrentUser(); + SseEmitter sseEmitter = new SseEmitter(); + if (!sseEmitterHelper.checkOrComplete(user, sseEmitter)) { + return sseEmitter; + } + sseEmitterHelper.startSse(user, sseEmitter); + _this.asyncSearch(user, sseEmitter, isBriefSearch, searchText, engineName, modelName); + return sseEmitter; + } + + @Async + public void asyncSearch(User user, SseEmitter sseEmitter, boolean isBriefSearch, String searchText, String engineName, String modelName) { + SearchResult searchResult = new SearchEngineContext(engineName).getEngine().search(searchText); + if (StringUtils.isNotBlank(searchResult.getErrorMessage())) { + sseEmitterHelper.sendAndComplete(user.getId(), sseEmitter, searchResult.getErrorMessage()); + return; + } + if (CollectionUtils.isEmpty(searchResult.getItems())) { + sseEmitterHelper.sendAndComplete(user.getId(), sseEmitter, B_NO_ANSWER.getInfo()); + return; + } + boolean sendFail = false; + try { + sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.AI_SEARCH_SOURCE_LINKS).data(searchResult.getItems())); + } catch (IOException e) { + sendFail = true; + log.error("asyncSearch error", e); + sseEmitterHelper.sendErrorAndComplete(user.getId(), sseEmitter, e.getMessage()); + } + if (sendFail) { + return; + } + if (isBriefSearch) { + briefSearch(user, searchText, modelName, searchResult.getItems(), sseEmitter); + } else { + detailSearch(user, searchText, engineName, modelName, searchResult.getItems(), sseEmitter); + } + } + + /** + * 1.Search by search engine + * 2.Create prompt by search response + * 3.Send prompt to llm + * + * @param user + * @param searchText + * @param modelName + * @param resultItems + * @param sseEmitter + */ + public void briefSearch(User user, String searchText, String modelName, List resultItems, SseEmitter sseEmitter) { + log.info("briefSearch,searchText:{}", searchText); + StringBuilder builder = new StringBuilder(); + for (SearchResultItem item : resultItems) { + builder.append(item.getSnippet()).append("\n\n"); + } + String ragQuestion = builder.toString(); + String prompt = RAGService.parsePromptTemplate(searchText, ragQuestion); + + SearchEngineResp resp = new SearchEngineResp().setItems(resultItems); + + SseAskParams sseAskParams = new SseAskParams(); + sseAskParams.setSystemMessage(StringUtils.EMPTY); + sseAskParams.setSseEmitter(sseEmitter); + sseAskParams.setUserMessage(prompt); + sseAskParams.setModelName(modelName); + sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, promptMeta, answerMeta) -> { + AiSearchRecord newRecord = new AiSearchRecord(); + newRecord.setUuid(UUID.randomUUID().toString().replace("-", "")); + newRecord.setQuestion(searchText); + newRecord.setSearchEngineResp(resp); + newRecord.setPrompt(prompt); + newRecord.setPromptTokens(promptMeta.getTokens()); + newRecord.setAnswer(response); + newRecord.setAnswerTokens(answerMeta.getTokens()); + newRecord.setUserUuid(user.getUuid()); + newRecord.setUserId(user.getId()); + aiSearchRecordService.save(newRecord); + }); + } + + /** + * 1.Search by search engine + * 2.Save the response to pgvector + * 3.Retrieve document and create prompt + * 4.Send prompt to llm + * + * @param user + * @param searchText + * @param engineName + * @param modelName + * @param resultItems + * @param sseEmitter + */ + public void detailSearch(User user, String searchText, String engineName, String modelName, List resultItems, SseEmitter sseEmitter) { + log.info("detailSearch,searchText:{}", searchText); + //Save to DB + SearchEngineResp resp = new SearchEngineResp().setItems(resultItems); + AiSearchRecord newRecord = new AiSearchRecord(); + String searchUuid = UUID.randomUUID().toString().replace("-", ""); + newRecord.setUuid(searchUuid); + newRecord.setQuestion(searchText); + newRecord.setSearchEngineResp(resp); + newRecord.setUserId(user.getId()); + newRecord.setUserUuid(user.getUuid()); + aiSearchRecordService.save(newRecord); + + CountDownLatch countDownLatch = new CountDownLatch(resultItems.size()); + for (int i = 0; i < resultItems.size(); i++) { + int finalI = i; + mainExecutor.execute(() -> { + try { + SearchResultItem item = resultItems.get(finalI); + String content; + if (finalI < 2) { + content = getContentFromRemote(item); + + //Fill content with html body text + item.setContent(content); + } else { + content = item.getSnippet(); + } + + //embedding + Metadata metadata = new Metadata(); + metadata.add(AdiConstant.EmbeddingMetadataKey.ENGINE_NAME, engineName); + metadata.add(AdiConstant.EmbeddingMetadataKey.SEARCH_UUID, searchUuid); + Document document = new Document(content, metadata); + searchRagService.ingest(document); + + } catch (Exception e) { + log.error("Detail search error,uuid:{}", searchUuid, e); + } finally { + countDownLatch.countDown(); + } + }); + } + try { + countDownLatch.await(); + } catch (InterruptedException e) { + log.error("CountDownLatch await error,uuid:{}", searchUuid, e); + throw new RuntimeException(e); + } + + log.info("Create prompt"); + Prompt prompt = searchRagService.retrieveAndCreatePrompt(ImmutableMap.of(AdiConstant.EmbeddingMetadataKey.SEARCH_UUID, searchUuid), searchText); + + SseAskParams sseAskParams = new SseAskParams(); + sseAskParams.setSystemMessage(StringUtils.EMPTY); + sseAskParams.setSseEmitter(sseEmitter); + sseAskParams.setUserMessage(prompt.text()); + sseAskParams.setModelName(modelName); + + log.info("Push to model"); + sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, promptMeta, answerMeta) -> { + + AiSearchRecord existRecord = aiSearchRecordService.lambdaQuery().eq(AiSearchRecord::getUuid, searchUuid).one(); + + AiSearchRecord updateRecord = new AiSearchRecord(); + updateRecord.setId(existRecord.getId()); + //Update search engine response content.(with html body text) + updateRecord.setSearchEngineResp(new SearchEngineResp().setItems(resultItems)); + updateRecord.setPrompt(prompt.text()); + updateRecord.setPromptTokens(promptMeta.getTokens()); + updateRecord.setAnswer(response); + updateRecord.setAnswerTokens(answerMeta.getTokens()); + aiSearchRecordService.updateById(updateRecord); + }); + } + + private String getContentFromRemote(SearchResultItem item) { + String result = ""; + try { + org.jsoup.nodes.Document doc = Jsoup.connect(item.getLink()).ignoreContentType(true).get(); + if (doc.getElementsByTag("main").size() > 0) { + result = doc.getElementsByTag("main").get(0).html(); + } else { + result = doc.body().html(); + } + if (StringUtils.isBlank(result)) { + log.error("Empty content from {}, use snippet instead", item.getLink()); + return item.getSnippet(); + } + } catch (Exception e) { + log.error("Failed to load document from {}, use snippet instead", item.getLink(), e); + } + Cleaner cleaner = new Cleaner(Safelist.none()); + return cleaner.clean(Jsoup.parse(result)).text(); + } +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/util/AdiPgVectorEmbeddingStore.java b/adi-common/src/main/java/com/moyz/adi/common/util/AdiPgVectorEmbeddingStore.java index 1d20bfe..7083051 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/util/AdiPgVectorEmbeddingStore.java +++ b/adi-common/src/main/java/com/moyz/adi/common/util/AdiPgVectorEmbeddingStore.java @@ -9,6 +9,7 @@ import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; import lombok.Builder; +import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.math.NumberUtils; import org.slf4j.Logger; @@ -286,16 +287,24 @@ public class AdiPgVectorEmbeddingStore implements EmbeddingStore { } } - //adi - public List> findRelevantByKbUuid(String kbUuid, Embedding referenceEmbedding, int maxResults, double minScore) { + public List> findRelevantByMetadata(Map metadatCondition, Embedding referenceEmbedding, int maxResults, double minScore) { List> result = new ArrayList<>(); try (Connection connection = setupConnection()) { String referenceVector = Arrays.toString(referenceEmbedding.vector()); - //新增查询条件kb_id + //deal with metadata condition + StringBuilder whereSql = new StringBuilder(); + if (null != metadatCondition && !metadatCondition.isEmpty()) { + whereSql = new StringBuilder("where"); + for (String key : metadatCondition.keySet()) { + whereSql.append(" metadata->>'").append(key).append("' = '").append(metadatCondition.get(key)).append("' and"); + } + whereSql.replace(whereSql.length() - 3, whereSql.length(), ""); + } String query = String.format( - "WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s where metadata->>'kb_uuid' = '%s') SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;", - referenceVector, table, kbUuid, minScore, maxResults); + "WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s " + whereSql + ") SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;", + referenceVector, table, minScore, maxResults); + log.info(query); PreparedStatement selectStmt = connection.prepareStatement(query); ResultSet resultSet = selectStmt.executeQuery(); diff --git a/adi-common/src/main/java/com/moyz/adi/common/util/BizPager.java b/adi-common/src/main/java/com/moyz/adi/common/util/BizPager.java index df9ec63..8340196 100644 --- a/adi-common/src/main/java/com/moyz/adi/common/util/BizPager.java +++ b/adi-common/src/main/java/com/moyz/adi/common/util/BizPager.java @@ -10,6 +10,7 @@ import org.apache.commons.collections4.CollectionUtils; import java.util.List; import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.function.ObjLongConsumer; import java.util.function.Supplier; public class BizPager { @@ -68,6 +69,21 @@ public class BizPager { } while (!Thread.currentThread().isInterrupted() && records.size() == AdiConstant.DEFAULT_PAGE_SIZE); } + public static void listByMaxId(Long maxId, LambdaQueryWrapper queryWrapper, IService service, SFunction idSupplier, ObjLongConsumer> consumer) { + if (maxId > 0) { + queryWrapper.lt(idSupplier, maxId); + } + queryWrapper.orderByDesc(idSupplier); + queryWrapper.last("limit " + AdiConstant.DEFAULT_PAGE_SIZE); + + long minId = 0; + List records = service.list(queryWrapper); + if (CollectionUtils.isNotEmpty(records)) { + minId = records.stream().map(idSupplier).reduce(Long::min).get(); + } + consumer.accept(records, minId); + } + /** * 以Long类型的惟一字段(通常为id)为锚点,按页获取数据 *
不依赖mybatis-plus diff --git a/adi-common/src/main/java/com/moyz/adi/common/util/SpringUtil.java b/adi-common/src/main/java/com/moyz/adi/common/util/SpringUtil.java new file mode 100644 index 0000000..2d5400f --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/util/SpringUtil.java @@ -0,0 +1,25 @@ +package com.moyz.adi.common.util; + +import org.springframework.beans.BeansException; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.stereotype.Component; + +@Component +public class SpringUtil implements ApplicationContextAware { + + private static ApplicationContext applicationContext; + + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + SpringUtil.applicationContext = applicationContext; + } + + public static T getBean(String name) { + return (T) applicationContext.getBean(name); + } + + public static T getBean(Class clazz) { + return applicationContext.getBean(clazz); + } +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/vo/GoogleSetting.java b/adi-common/src/main/java/com/moyz/adi/common/vo/GoogleSetting.java new file mode 100644 index 0000000..467c9aa --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/vo/GoogleSetting.java @@ -0,0 +1,10 @@ +package com.moyz.adi.common.vo; + +import lombok.Data; + +@Data +public class GoogleSetting { + private String url; + private String key; + private String cx; +} diff --git a/adi-common/src/main/java/com/moyz/adi/common/vo/SearchEngineInfo.java b/adi-common/src/main/java/com/moyz/adi/common/vo/SearchEngineInfo.java new file mode 100644 index 0000000..00cfff1 --- /dev/null +++ b/adi-common/src/main/java/com/moyz/adi/common/vo/SearchEngineInfo.java @@ -0,0 +1,13 @@ +package com.moyz.adi.common.vo; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.moyz.adi.common.interfaces.AbstractSearchEngine; +import lombok.Data; + +@Data +public class SearchEngineInfo { + private String name; + private Boolean enable; + @JsonIgnore + private AbstractSearchEngine engine; +} diff --git a/docs/create.sql b/docs/create.sql index 6d867a8..4a5ef88 100644 --- a/docs/create.sql +++ b/docs/create.sql @@ -390,6 +390,9 @@ VALUES ('qianfan_setting', '{"api_key":"","secret_key":"","models":[]}'); INSERT INTO adi_sys_config (name, value) VALUES ('ollama_setting', '{"base_url":"","models":[]}'); INSERT INTO adi_sys_config (name, value) +VALUES ('google_setting', + '{"url":"https://www.googleapis.com/customsearch/v1","key":"","cx":""}'); +INSERT INTO adi_sys_config (name, value) VALUES ('request_text_rate_limit', '{"times":24,"minutes":3}'); INSERT INTO adi_sys_config (name, value) VALUES ('request_image_rate_limit', '{"times":6,"minutes":3}'); @@ -544,3 +547,48 @@ create trigger trigger_kb_qa_record_update_time on adi_knowledge_base_qa_record for each row execute procedure update_modified_column(); + +-- ai search +create table adi_ai_search_record +( + id bigserial primary key, + uuid varchar(32) default ''::character varying not null, + question varchar(1000) default ''::character varying not null, + search_engine_response jsonb not null, + prompt text default ''::character varying not null, + prompt_tokens integer DEFAULT 0 NOT NULL, + answer text default ''::character varying not null, + answer_tokens integer DEFAULT 0 NOT NULL, + user_id bigint default '0' NOT NULL, + user_uuid varchar(32) default ''::character varying not null, + create_time timestamp default CURRENT_TIMESTAMP not null, + update_time timestamp default CURRENT_TIMESTAMP not null, + is_deleted boolean default false not null +); +comment on table adi_ai_search_record is 'Search record'; + +comment on column adi_ai_search_record.question is 'User original question'; + +comment on column adi_ai_search_record.search_engine_response is 'Search engine''s response content'; + +comment on column adi_ai_search_record.prompt is 'Prompt of LLM'; + +comment on column adi_ai_search_record.prompt_tokens is 'prompt消耗的token数量'; + +comment on column adi_ai_search_record.answer is 'LLM response'; + +comment on column adi_ai_search_record.answer_tokens is 'LLM响应消耗的token数量'; + +comment on column adi_ai_search_record.user_id is 'Id from adi_user'; + +comment on column adi_ai_search_record.create_time is '创建时间'; + +comment on column adi_ai_search_record.update_time is '更新时间'; + +comment on column adi_ai_search_record.is_deleted is '0: Normal; 1: Deleted'; + +create trigger trigger_ai_search_record + before update + on adi_ai_search_record + for each row +execute procedure update_modified_column(); \ No newline at end of file