add ai search

This commit is contained in:
moyangzhan 2024-04-08 00:17:23 +08:00
parent abd9ced7cc
commit 86a70e09ca
49 changed files with 1280 additions and 106 deletions

View File

@ -17,7 +17,9 @@
* 提示词 * 提示词
* 额度控制 * 额度控制
* 基于大模型的知识库RAG * 基于大模型的知识库RAG
* 基于大模型的搜索RAG
* 多模型随意切换 * 多模型随意切换
* 多搜索引擎随意切换
## 接入的模型: ## 接入的模型:
@ -27,6 +29,14 @@
* ollama * ollama
* DALL-E 2 * DALL-E 2
## 接入的搜索引擎
Google
Bing (TODO)
百度 (TODO)
## 技术栈 ## 技术栈
该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web) 该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web)
@ -53,7 +63,7 @@ vue3+typescript+pnpm
* 创建数据库aideepin * 创建数据库aideepin
* 执行docs/create.sql * 执行docs/create.sql
* 填充各模型的配置 * 填充各模型的配置(至少设置一个)
openai的secretKey openai的secretKey
@ -79,6 +89,15 @@ ollama的配置
update adi_sys_config set value = '{"base_url":"my_ollama_base_url","models":["my model name,eg:tinydolphin"]}' where name = 'ollama_setting'; update adi_sys_config set value = '{"base_url":"my_ollama_base_url","models":["my model name,eg:tinydolphin"]}' where name = 'ollama_setting';
``` ```
* 填充搜索引擎的配置
Google的配置
```
update adi_sys_config set value = '{"url":"https://www.googleapis.com/customsearch/v1","key":"my key from cloud.google.com","cx":"my cx from programmablesearchengine.google.com"}' where name = 'google_setting';
```
**b. 修改配置文件** **b. 修改配置文件**
* postgresql: application-[dev|prod].xml中的spring.datasource * postgresql: application-[dev|prod].xml中的spring.datasource
@ -122,23 +141,30 @@ docker run -d \
## 待办: ## 待办:
* AI搜索 增强RAG
* 增强RAG
增加搜索引擎BING、百度
## 截图 ## 截图
**AI聊天** **AI聊天**
![1691583184761](image/README/1691583184761.png) ![1691583184761](image/README/1691583184761.png)
![1691583124744](image/README/1691583124744.png "AI绘图") **AI画图**
![1691583329105](image/README/1691583329105.png "token统计") ![1691583124744](image/README/1691583124744.png "AI绘图")
**知识库:** **知识库:**
![kbindex](image/README/kbidx.png) ![kbindex](image/README/kbidx.png)
![kb01](image/README/kb01.png) ![kb01](image/README/kb01.png)
**向量化:**
![kb02](image/README/kb02.png) ![kb02](image/README/kb02.png)
![kb03](image/README/kb03.png) ![kb03](image/README/kb03.png)
**额度统计:**
![1691583329105](https://file+.vscode-resource.vscode-cdn.net/e%3A/WORKSPACE/aideepin/image/README/1691583329105.png "token统计")

View File

@ -11,6 +11,9 @@ spring:
name: AiDeepIn name: AiDeepIn
profiles: profiles:
active: dev active: dev
mvc:
async:
request-timeout: 60000
jackson: jackson:
date-format: "yyyy-MM-dd HH:mm:ss" date-format: "yyyy-MM-dd HH:mm:ss"
time-zone: "GMT+8" time-zone: "GMT+8"

View File

@ -3,7 +3,9 @@ package com.moyz.adi.chat.controller;
import com.moyz.adi.common.dto.LoginReq; import com.moyz.adi.common.dto.LoginReq;
import com.moyz.adi.common.dto.LoginResp; import com.moyz.adi.common.dto.LoginResp;
import com.moyz.adi.common.dto.RegisterReq; import com.moyz.adi.common.dto.RegisterReq;
import com.moyz.adi.common.searchengine.SearchEngineContext;
import com.moyz.adi.common.service.UserService; import com.moyz.adi.common.service.UserService;
import com.moyz.adi.common.vo.SearchEngineInfo;
import com.ramostear.captcha.HappyCaptcha; import com.ramostear.captcha.HappyCaptcha;
import com.ramostear.captcha.support.CaptchaType; import com.ramostear.captcha.support.CaptchaType;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
@ -23,6 +25,8 @@ import org.springframework.web.bind.annotation.*;
import java.io.IOException; import java.io.IOException;
import java.net.URLEncoder; import java.net.URLEncoder;
import java.util.List;
import java.util.stream.Collectors;
import static org.springframework.http.HttpHeaders.AUTHORIZATION; import static org.springframework.http.HttpHeaders.AUTHORIZATION;
@ -123,4 +127,9 @@ public class AuthController {
happyCaptcha.output(); happyCaptcha.output();
} }
@Operation(summary = "Search engine list")
@GetMapping(value = "/search-engine/list")
public List<SearchEngineInfo> engines() {
return SearchEngineContext.NAME_TO_ENGINE.values().stream().collect(Collectors.toList());
}
} }

View File

@ -70,7 +70,7 @@ public class KnowledgeBaseController {
* *
* @return * @return
*/ */
@PostMapping("/star/{uuid}") @PostMapping("/star/{kbUuid}")
public boolean star(@PathVariable String kbUuid) { public boolean star(@PathVariable String kbUuid) {
return knowledgeBaseService.star(kbUuid); return knowledgeBaseService.star(kbUuid);
} }

View File

@ -44,6 +44,6 @@ public class KnowledgeBaseQAController {
@PostMapping("/record/del/{uuid}") @PostMapping("/record/del/{uuid}")
public boolean recordDel(@PathVariable String uuid) { public boolean recordDel(@PathVariable String uuid) {
return knowledgeBaseQaRecordService.softDelele(uuid); return knowledgeBaseQaRecordService.softDelete(uuid);
} }
} }

View File

@ -0,0 +1,29 @@
package com.moyz.adi.chat.controller;
import com.moyz.adi.common.dto.AiSearchReq;
import com.moyz.adi.common.service.SearchService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import org.springframework.http.MediaType;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@Tag(name = "AI search controller")
@RequestMapping("/ai-search/")
@RestController
public class SearchController {
@Resource
private SearchService searchService;
@Operation(summary = "sse process")
@PostMapping(value = "/process", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter sseAsk(@RequestBody @Validated AiSearchReq req) {
return searchService.search(req.isBriefSearch(), req.getSearchText(), req.getEngineName(), req.getModelName());
}
}

View File

@ -0,0 +1,28 @@
package com.moyz.adi.chat.controller;
import com.moyz.adi.common.dto.AiSearchResp;
import com.moyz.adi.common.service.AiSearchRecordService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import org.springframework.web.bind.annotation.*;
@Tag(name = "Ai search record controller")
@RequestMapping("/ai-search-record/")
@RestController
public class SearchRecordController {
@Resource
private AiSearchRecordService aiSearchRecordService;
@Operation(summary = "List by max id")
@GetMapping(value = "/list")
public AiSearchResp list(@RequestParam(defaultValue = "0") Long maxId, String keyword) {
return aiSearchRecordService.listByMaxId(maxId, keyword);
}
@PostMapping("/del/{uuid}")
public boolean recordDel(@PathVariable String uuid) {
return aiSearchRecordService.softDelete(uuid);
}
}

View File

@ -0,0 +1,73 @@
package com.moyz.adi.common.base;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.MappedJdbcTypes;
import org.apache.ibatis.type.MappedTypes;
import org.postgresql.util.PGobject;
import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
@MappedJdbcTypes({JdbcType.JAVA_OBJECT})
@MappedTypes({JsonNode.class})
public class JsonNodeTypeHandler extends BaseTypeHandler<JsonNode> {
private static final ObjectMapper objectMapper = new ObjectMapper();
@Override
public void setNonNullParameter(PreparedStatement ps, int i, JsonNode parameter, JdbcType jdbcType)
throws SQLException {
PGobject jsonObject = new PGobject();
jsonObject.setType("jsonb");
try {
jsonObject.setValue(parameter.toString());
ps.setObject(i, jsonObject);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public JsonNode getNullableResult(ResultSet rs, String columnName) throws SQLException {
String jsonSource = rs.getString(columnName);
if (jsonSource != null) {
try {
return objectMapper.readTree(jsonSource);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
return null;
}
@Override
public JsonNode getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
String jsonSource = rs.getString(columnIndex);
if (jsonSource != null) {
try {
return objectMapper.readTree(jsonSource);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
return null;
}
@Override
public JsonNode getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
String jsonSource = cs.getString(columnIndex);
if (jsonSource != null) {
try {
return objectMapper.readTree(jsonSource);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
return null;
}
}

View File

@ -0,0 +1,70 @@
package com.moyz.adi.common.base;
import com.moyz.adi.common.dto.SearchEngineResp;
import com.moyz.adi.common.util.JsonUtil;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.MappedJdbcTypes;
import org.apache.ibatis.type.MappedTypes;
import org.postgresql.util.PGobject;
import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
@MappedJdbcTypes({JdbcType.JAVA_OBJECT})
@MappedTypes({SearchEngineResp.class})
public class SearchEngineRespTypeHandler extends BaseTypeHandler<SearchEngineResp> {
@Override
public void setNonNullParameter(PreparedStatement ps, int i, SearchEngineResp parameter, JdbcType jdbcType) {
PGobject jsonObject = new PGobject();
jsonObject.setType("jsonb");
try {
jsonObject.setValue(JsonUtil.toJson(parameter));
ps.setObject(i, jsonObject);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public SearchEngineResp getNullableResult(ResultSet rs, String columnName) throws SQLException {
String jsonSource = rs.getString(columnName);
if (jsonSource != null) {
try {
return JsonUtil.fromJson(jsonSource, SearchEngineResp.class);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
return null;
}
@Override
public SearchEngineResp getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
String jsonSource = rs.getString(columnIndex);
if (jsonSource != null) {
try {
return JsonUtil.fromJson(jsonSource, SearchEngineResp.class);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
return null;
}
@Override
public SearchEngineResp getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
String jsonSource = cs.getString(columnIndex);
if (jsonSource != null) {
try {
return JsonUtil.fromJson(jsonSource, SearchEngineResp.class);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
return null;
}
}

View File

@ -10,11 +10,14 @@ import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.google.common.collect.Lists; import com.moyz.adi.common.base.SearchEngineRespTypeHandler;
import com.moyz.adi.common.dto.SearchEngineResp;
import com.moyz.adi.common.service.RAGService;
import com.moyz.adi.common.util.LocalDateTimeUtil; import com.moyz.adi.common.util.LocalDateTimeUtil;
import com.pgvector.PGvector; import com.pgvector.PGvector;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.session.SqlSessionFactory; import org.apache.ibatis.session.SqlSessionFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary; import org.springframework.context.annotation.Primary;
@ -32,6 +35,15 @@ import javax.sql.DataSource;
@Configuration @Configuration
public class BeanConfig { public class BeanConfig {
@Value("${spring.datasource.url}")
private String dataBaseUrl;
@Value("${spring.datasource.username}")
private String dataBaseUserName;
@Value("${spring.datasource.password}")
private String dataBasePassword;
@Bean @Bean
public RestTemplate restTemplate() { public RestTemplate restTemplate() {
log.info("Configuration:create restTemplate"); log.info("Configuration:create restTemplate");
@ -42,7 +54,7 @@ public class BeanConfig {
requestFactory.setReadTimeout(60000); requestFactory.setReadTimeout(60000);
RestTemplate restTemplate = new RestTemplate(); RestTemplate restTemplate = new RestTemplate();
// 注册LOG拦截器 // 注册LOG拦截器
restTemplate.setInterceptors(Lists.newArrayList(new LogClientHttpRequestInterceptor())); // restTemplate.setInterceptors(Lists.newArrayList(new LogClientHttpRequestInterceptor()));
restTemplate.setRequestFactory(new BufferingClientHttpRequestFactory(requestFactory)); restTemplate.setRequestFactory(new BufferingClientHttpRequestFactory(requestFactory));
return restTemplate; return restTemplate;
@ -95,12 +107,28 @@ public class BeanConfig {
bean.setMapperLocations( bean.setMapperLocations(
new PathMatchingResourcePatternResolver().getResources("classpath*:/mapper/*.xml")); new PathMatchingResourcePatternResolver().getResources("classpath*:/mapper/*.xml"));
MybatisConfiguration configuration = bean.getConfiguration(); MybatisConfiguration configuration = bean.getConfiguration();
if(null == configuration){ if (null == configuration) {
configuration = new MybatisConfiguration(); configuration = new MybatisConfiguration();
bean.setConfiguration(configuration); bean.setConfiguration(configuration);
} }
bean.getConfiguration().getTypeHandlerRegistry().register(PGvector.class, PostgresVectorTypeHandler.class); bean.getConfiguration().getTypeHandlerRegistry().register(PGvector.class, PostgresVectorTypeHandler.class);
bean.getConfiguration().getTypeHandlerRegistry().register(SearchEngineResp.class, SearchEngineRespTypeHandler.class);
return bean.getObject(); return bean.getObject();
} }
@Bean(name = "kbRagService")
@Primary
public RAGService initKnowledgeBaseRAGService() {
RAGService ragService = new RAGService("adi_knowledge_base_embedding", dataBaseUrl, dataBaseUserName, dataBasePassword);
ragService.init();
return ragService;
}
@Bean(name = "searchRagService")
public RAGService initSearchRAGService() {
RAGService ragService = new RAGService("adi_ai_search_embedding", dataBaseUrl, dataBaseUserName, dataBasePassword);
ragService.init();
return ragService;
}
} }

View File

@ -1,10 +1,12 @@
package com.moyz.adi.common.cosntant; package com.moyz.adi.common.cosntant;
import dev.langchain4j.model.input.PromptTemplate;
import java.util.List; import java.util.List;
public class AdiConstant { public class AdiConstant {
public static final int DEFAULT_PAGE_SIZE = 1; public static final int DEFAULT_PAGE_SIZE = 10;
/** /**
* 验证码id过期时间1小时 * 验证码id过期时间1小时
@ -53,6 +55,13 @@ public class AdiConstant {
public static final List<String> OPENAI_CREATE_IMAGE_SIZES = List.of("256x256", "512x512", "1024x1024"); public static final List<String> OPENAI_CREATE_IMAGE_SIZES = List.of("256x256", "512x512", "1024x1024");
public static final PromptTemplate PROMPT_TEMPLATE = PromptTemplate.from("""
根据以下已知信息:
{{information}}
尽可能准确地回答用户的问题,以下是用户的问题:
{{question}}
注意,回答的内容不能让用户感知到已知信息的存在
""");
public static class GenerateImage { public static class GenerateImage {
public static final int INTERACTING_METHOD_GENERATE_IMAGE = 1; public static final int INTERACTING_METHOD_GENERATE_IMAGE = 1;
@ -64,11 +73,21 @@ public class AdiConstant {
public static final int STATUS_SUCCESS = 3; public static final int STATUS_SUCCESS = 3;
} }
public static class EmbeddingMetadataKey {
public static final String KB_UUID = "kb_uuid";
public static final String KB_ITEM_UUID = "kb_item_uuid";
public static final String ENGINE_NAME = "engine_name";
public static final String SEARCH_UUID = "search_uuid";
}
public static class SysConfigKey { public static class SysConfigKey {
public static final String OPENAI_SETTING = "openai_setting"; public static final String OPENAI_SETTING = "openai_setting";
public static final String DASHSCOPE_SETTING = "dashscope_setting"; public static final String DASHSCOPE_SETTING = "dashscope_setting";
public static final String QIANFAN_SETTING = "qianfan_setting"; public static final String QIANFAN_SETTING = "qianfan_setting";
public static final String OLLAMA_SETTING = "ollama_setting"; public static final String OLLAMA_SETTING = "ollama_setting";
public static final String GOOGLE_SETTING = "google_setting";
public static final String BING_SETTING = "bing_setting";
public static final String BAIDU_SETTING = "baidu_setting";
public static final String REQUEST_TEXT_RATE_LIMIT = "request_text_rate_limit"; public static final String REQUEST_TEXT_RATE_LIMIT = "request_text_rate_limit";
public static final String REQUEST_IMAGE_RATE_LIMIT = "request_image_rate_limit"; public static final String REQUEST_IMAGE_RATE_LIMIT = "request_image_rate_limit";
public static final String CONVERSATION_MAX_NUM = "conversation_max_num"; public static final String CONVERSATION_MAX_NUM = "conversation_max_num";
@ -78,8 +97,25 @@ public class AdiConstant {
public static final String QUOTA_BY_REQUEST_MONTHLY = "quota_by_request_monthly"; public static final String QUOTA_BY_REQUEST_MONTHLY = "quota_by_request_monthly";
public static final String QUOTA_BY_IMAGE_DAILY = "quota_by_image_daily"; public static final String QUOTA_BY_IMAGE_DAILY = "quota_by_image_daily";
public static final String QUOTA_BY_IMAGE_MONTHLY = "quota_by_image_monthly"; public static final String QUOTA_BY_IMAGE_MONTHLY = "quota_by_image_monthly";
public static final String QUOTA_BY_QA_ASK_DAILY = "quota_by_qa_ask_daily";
} }
public static final String[] POI_DOC_TYPES = {"doc", "docx", "ppt", "pptx", "xls", "xlsx"}; public static final String[] POI_DOC_TYPES = {"doc", "docx", "ppt", "pptx", "xls", "xlsx"};
public static class SearchEngineName {
public static final String GOOGLE = "google";
public static final String BING = "bing";
public static final String BAIDU = "baidu";
}
public static class SSEEventName {
public static final String START = "[START]";
public static final String DONE = "[DONE]";
public static final String ERROR = "[ERROR]";
public static final String AI_SEARCH_SOURCE_LINKS = "[SOURCE_LINKS]";
}
public static final int RAG_TYPE_KB = 1;
public static final int RAG_TYPE_SEARCH = 2;
} }

View File

@ -0,0 +1,33 @@
package com.moyz.adi.common.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDateTime;
@NoArgsConstructor
@AllArgsConstructor
@Builder
@Data
public class AiSearchRecordResp {
private String uuid;
private String question;
private SearchEngineResp searchEngineResp;
private String prompt;
private Integer promptTokens;
private String answer;
private Integer answerTokens;
private String userUuid;
private LocalDateTime createTime;
}

View File

@ -0,0 +1,19 @@
package com.moyz.adi.common.dto;
import jakarta.validation.constraints.NotBlank;
import lombok.Data;
import org.springframework.validation.annotation.Validated;
@Validated
@Data
public class AiSearchReq {
@NotBlank
private String searchText;
private String engineName;
private String modelName;
private boolean briefSearch;
}

View File

@ -0,0 +1,13 @@
package com.moyz.adi.common.dto;
import lombok.Data;
import java.util.List;
@Data
public class AiSearchResp {
private Long minId;
private List<AiSearchRecordResp> records;
}

View File

@ -0,0 +1,9 @@
package com.moyz.adi.common.dto;
import lombok.Data;
@Data
public class GoogleSearchError {
private Integer code;
private String message;
}

View File

@ -0,0 +1,47 @@
package com.moyz.adi.common.dto;
import lombok.Data;
import java.util.List;
@Data
public class GoogleSearchResp {
private String kind;
private Queries queries;
private SearchInformation searchInformation;
private List<Item> items;
private GoogleSearchError error;
@Data
public static class Queries {
private Request[] request;
}
@Data
public static class Request {
private String title;
private String totalResults;
private String searchTerms;
private Integer count;
private Integer startIndex;
private String inputEncoding;
private String outputEncoding;
}
@Data
public static class SearchInformation {
private double searchTime;
private String formattedSearchTime;
private String totalResults;
private String formattedTotalResults;
}
@Data
public static class Item {
private String kind;
private String title;
private String htmlTitle;
private String link;
private String snippet;
}
}

View File

@ -0,0 +1,12 @@
package com.moyz.adi.common.dto;
import lombok.Data;
import lombok.experimental.Accessors;
import java.util.List;
@Data
@Accessors(chain = true)
public class SearchEngineResp {
private List<SearchResultItem> items;
}

View File

@ -0,0 +1,11 @@
package com.moyz.adi.common.dto;
import lombok.Data;
import java.util.List;
@Data
public class SearchResult {
private String errorMessage;
private List<SearchResultItem> items;
}

View File

@ -0,0 +1,11 @@
package com.moyz.adi.common.dto;
import lombok.Data;
@Data
public class SearchResultItem {
private String title;
private String link;
private String snippet;
private String content;
}

View File

@ -0,0 +1,50 @@
package com.moyz.adi.common.entity;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import com.moyz.adi.common.base.SearchEngineRespTypeHandler;
import com.moyz.adi.common.dto.SearchEngineResp;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import org.apache.ibatis.type.JdbcType;
@Data
@TableName("adi_ai_search_record")
@Schema(title = "AiSearchRecord对象", description = "AI搜索记录表")
public class AiSearchRecord extends BaseEntity {
@TableField("uuid")
private String uuid;
@Schema(title = "问题")
@TableField("question")
private String question;
@Schema(title = "Search engine's response content")
@TableField(value = "search_engine_response", jdbcType = JdbcType.JAVA_OBJECT, typeHandler = SearchEngineRespTypeHandler.class)
private SearchEngineResp searchEngineResp;
@Schema(title = "最终提供给LLM的提示词")
@TableField("prompt")
private String prompt;
@Schema(title = "提供给LLM的提示词所消耗的token数量")
@TableField("prompt_tokens")
private Integer promptTokens;
@Schema(title = "答案")
@TableField("answer")
private String answer;
@Schema(title = "答案消耗的token")
@TableField("answer_tokens")
private Integer answerTokens;
@Schema(title = "提问用户uuid")
@TableField("user_uuid")
private String userUuid;
@Schema(title = "提问用户id")
@TableField("user_id")
private Long userId;
}

View File

@ -34,7 +34,7 @@ public enum ErrorEnum {
B_MESSAGE_NOT_FOUND("B0008", "消息不存在"), B_MESSAGE_NOT_FOUND("B0008", "消息不存在"),
B_LLM_SERVICE_DISABLED("B0009", "LLM服务不可用"), B_LLM_SERVICE_DISABLED("B0009", "LLM服务不可用"),
B_KNOWLEDGE_BASE_IS_EMPTY("B0010", "知识库内容为空"), B_KNOWLEDGE_BASE_IS_EMPTY("B0010", "知识库内容为空"),
B_KNOWLEDGE_BASE_NO_ANSWER("B0011", "[无答案]") B_NO_ANSWER("B0011", "[无答案]")
; ;
private String code; private String code;

View File

@ -17,10 +17,10 @@ public class QuotaHelper {
private UserDayCostService userDayCostService; private UserDayCostService userDayCostService;
public ErrorEnum checkTextQuota(User user) { public ErrorEnum checkTextQuota(User user) {
if (StringUtils.isNotBlank(user.getSecretKey())) { // if (StringUtils.isNotBlank(user.getSecretKey())) {
log.info("Custom secret key,dont need to check text request quota,userId:{}", user.getId()); // log.info("Custom secret key,dont need to check text request quota,userId:{}", user.getId());
return null; // return null;
} // }
int userQuotaByTokenDay = user.getQuotaByTokenDaily(); int userQuotaByTokenDay = user.getQuotaByTokenDaily();
int userQuotaByTokenMonth = user.getQuotaByTokenMonthly(); int userQuotaByTokenMonth = user.getQuotaByTokenMonthly();
int userQuotaByRequestDay = user.getQuotaByRequestDaily(); int userQuotaByRequestDay = user.getQuotaByRequestDaily();

View File

@ -1,5 +1,6 @@
package com.moyz.adi.common.helper; package com.moyz.adi.common.helper;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.cosntant.RedisKeyConstant; import com.moyz.adi.common.cosntant.RedisKeyConstant;
import com.moyz.adi.common.entity.User; import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.interfaces.TriConsumer; import com.moyz.adi.common.interfaces.TriConsumer;
@ -28,35 +29,45 @@ public class SSEEmitterHelper {
@Resource @Resource
private RateLimitHelper rateLimitHelper; private RateLimitHelper rateLimitHelper;
public void process(User user, SseAskParams sseAskParams, TriConsumer<String, PromptMeta, AnswerMeta> consumer) { public boolean checkOrComplete(User user, SseEmitter sseEmitter) {
SseEmitter sseEmitter = sseAskParams.getSseEmitter(); //Check: rate limit
//rate limit by system
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId()); String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) { if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) {
sendErrorMsg(sseEmitter, "访问太过频繁"); sendErrorAndComplete(user.getId(), sseEmitter, "访问太过频繁");
return; return false;
} }
//Check: If still waiting response //Check: If still waiting response
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId()); String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
String askingVal = stringRedisTemplate.opsForValue().get(askingKey); String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
if (StringUtils.isNotBlank(askingVal)) { if (StringUtils.isNotBlank(askingVal)) {
sendErrorMsg(sseEmitter, "正在回复中..."); sendErrorAndComplete(user.getId(), sseEmitter, "正在回复中...");
return; return false;
} }
return true;
}
public void startSse(User user, SseEmitter sseEmitter) {
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
stringRedisTemplate.opsForValue().set(askingKey, "1", 15, TimeUnit.SECONDS); stringRedisTemplate.opsForValue().set(askingKey, "1", 15, TimeUnit.SECONDS);
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG);
try { try {
sseEmitter.send(SseEmitter.event().name("[START]")); sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.START));
} catch (IOException e) { } catch (IOException e) {
log.error("error", e); log.error("error", e);
sseEmitter.completeWithError(e); sseEmitter.completeWithError(e);
stringRedisTemplate.delete(askingKey); stringRedisTemplate.delete(askingKey);
return;
} }
}
rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG);
public void processAndPushToModel(User user, SseAskParams sseAskParams, TriConsumer<String, PromptMeta, AnswerMeta> consumer) {
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
SseEmitter sseEmitter = sseAskParams.getSseEmitter();
sseEmitter.onCompletion(() -> { sseEmitter.onCompletion(() -> {
log.info("response complete,uid:{}", user.getId()); log.info("response complete,uid:{}", user.getId());
}); });
@ -65,7 +76,7 @@ public class SSEEmitterHelper {
throwable -> { throwable -> {
try { try {
log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable); log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable);
sseEmitter.send(SseEmitter.event().name("[ERROR]").data(throwable.getMessage())); sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.ERROR).data(throwable.getMessage()));
} catch (IOException e) { } catch (IOException e) {
log.error("error", e); log.error("error", e);
} finally { } finally {
@ -84,22 +95,29 @@ public class SSEEmitterHelper {
}); });
} }
public void sendAndComplete(SseEmitter sseEmitter, String msg){ public void sendAndComplete(long userId, SseEmitter sseEmitter, String msg) {
try { try {
sseEmitter.send(SseEmitter.event().name("[START]")); sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.START));
sseEmitter.send(SseEmitter.event().name("[DONE]").data(msg)); sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.DONE).data(msg));
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
sseEmitter.complete(); sseEmitter.complete();
delSseRequesting(userId);
} }
public void sendErrorMsg(SseEmitter sseEmitter, String errorMsg) { public void sendErrorAndComplete(long userId, SseEmitter sseEmitter, String errorMsg) {
try { try {
sseEmitter.send(SseEmitter.event().name("[ERROR]").data(errorMsg)); sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.ERROR).data(errorMsg));
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
sseEmitter.complete(); sseEmitter.complete();
delSseRequesting(userId);
}
private void delSseRequesting(long userId) {
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, userId);
stringRedisTemplate.delete(askingKey);
} }
} }

View File

@ -39,13 +39,17 @@ public abstract class AbstractImageModelService<T> {
protected ImageModel imageModel; protected ImageModel imageModel;
public AbstractImageModelService(String modelName, String settingName, Class<T> clazz, Proxy proxy){ public AbstractImageModelService(String modelName, String settingName, Class<T> clazz) {
this.modelName = modelName; this.modelName = modelName;
this.proxy = proxy;
String st = LocalCache.CONFIGS.get(settingName); String st = LocalCache.CONFIGS.get(settingName);
setting = JsonUtil.fromJson(st, clazz); setting = JsonUtil.fromJson(st, clazz);
} }
public AbstractImageModelService setProxy(Proxy proxy) {
this.proxy = proxy;
return this;
}
public ImageModel getImageModel(User user, String size) { public ImageModel getImageModel(User user, String size) {
if (null != imageModel) { if (null != imageModel) {
return imageModel; return imageModel;
@ -56,6 +60,7 @@ public abstract class AbstractImageModelService<T> {
/** /**
* 检测该service是否可用不可用的情况通过是没有配置key * 检测该service是否可用不可用的情况通过是没有配置key
*
* @return * @return
*/ */
public abstract boolean isEnabled(); public abstract boolean isEnabled();

View File

@ -36,13 +36,17 @@ public abstract class AbstractLLMService<T> {
protected StreamingChatLanguageModel streamingChatLanguageModel; protected StreamingChatLanguageModel streamingChatLanguageModel;
protected ChatLanguageModel chatLanguageModel; protected ChatLanguageModel chatLanguageModel;
public AbstractLLMService(String modelName, String settingName, Class<T> clazz, Proxy proxy) { public AbstractLLMService(String modelName, String settingName, Class<T> clazz) {
this.modelName = modelName; this.modelName = modelName;
this.proxy = proxy;
String st = LocalCache.CONFIGS.get(settingName); String st = LocalCache.CONFIGS.get(settingName);
setting = JsonUtil.fromJson(st, clazz); setting = JsonUtil.fromJson(st, clazz);
} }
public AbstractLLMService setProxy(Proxy proxy) {
this.proxy = proxy;
return this;
}
/** /**
* 检测该service是否可用不可用的情况通常是没有配置key * 检测该service是否可用不可用的情况通常是没有配置key
* *
@ -73,7 +77,7 @@ public abstract class AbstractLLMService<T> {
protected abstract String parseError(Object error); protected abstract String parseError(Object error);
public Response<AiMessage> chat(ChatMessage chatMessage) { public Response<AiMessage> chat(ChatMessage chatMessage) {
if(!isEnabled()){ if (!isEnabled()) {
log.error("llm service is disabled"); log.error("llm service is disabled");
throw new BaseException(B_LLM_SERVICE_DISABLED); throw new BaseException(B_LLM_SERVICE_DISABLED);
} }
@ -81,7 +85,7 @@ public abstract class AbstractLLMService<T> {
} }
public void sseChat(SseAskParams params, TriConsumer<String, PromptMeta, AnswerMeta> consumer) { public void sseChat(SseAskParams params, TriConsumer<String, PromptMeta, AnswerMeta> consumer) {
if(!isEnabled()){ if (!isEnabled()) {
log.error("llm service is disabled"); log.error("llm service is disabled");
throw new BaseException(B_LLM_SERVICE_DISABLED); throw new BaseException(B_LLM_SERVICE_DISABLED);
} }
@ -131,7 +135,7 @@ public abstract class AbstractLLMService<T> {
log.error("stream error", error); log.error("stream error", error);
try { try {
String errorMsg = parseError(error); String errorMsg = parseError(error);
if(StringUtils.isBlank(errorMsg)){ if (StringUtils.isBlank(errorMsg)) {
errorMsg = error.getMessage(); errorMsg = error.getMessage();
} }
params.getSseEmitter().send(SseEmitter.event().name("[ERROR]").data(errorMsg)); params.getSseEmitter().send(SseEmitter.event().name("[ERROR]").data(errorMsg));

View File

@ -0,0 +1,45 @@
package com.moyz.adi.common.interfaces;
import com.moyz.adi.common.dto.SearchResult;
import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.util.SpringUtil;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;
import java.net.Proxy;
public abstract class AbstractSearchEngine<T> {
protected String engineName;
protected Proxy proxy;
public AbstractSearchEngine(String engineName, String settingName, Class<T> clazz) {
this.engineName = engineName;
String st = LocalCache.CONFIGS.get(settingName);
setting = JsonUtil.fromJson(st, clazz);
}
protected T setting;
public abstract boolean isEnabled();
public abstract SearchResult search(String searchTxt);
public AbstractSearchEngine setProxy(Proxy proxy) {
this.proxy = proxy;
return this;
}
protected RestTemplate getRestTemplate() {
RestTemplate restTemplate = SpringUtil.getBean(RestTemplate.class);
if (null != proxy) {
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
requestFactory.setProxy(proxy);
restTemplate.setRequestFactory(requestFactory);
}
return restTemplate;
}
}

View File

@ -0,0 +1,11 @@
package com.moyz.adi.common.interfaces;
public abstract class AbstractSearchService {
public abstract boolean isEnabled();
public abstract void briefSearch(String question);
public abstract void detailSearch(String question);
}

View File

@ -0,0 +1,9 @@
package com.moyz.adi.common.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.moyz.adi.common.entity.AiSearchRecord;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface AiSearchRecordMapper extends BaseMapper<AiSearchRecord> {
}

View File

@ -0,0 +1,56 @@
package com.moyz.adi.common.searchengine;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.dto.GoogleSearchResp;
import com.moyz.adi.common.dto.SearchResult;
import com.moyz.adi.common.dto.SearchResultItem;
import com.moyz.adi.common.interfaces.AbstractSearchEngine;
import com.moyz.adi.common.util.MPPageUtil;
import com.moyz.adi.common.vo.GoogleSetting;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.List;
import static com.moyz.adi.common.cosntant.AdiConstant.SysConfigKey.GOOGLE_SETTING;
@Slf4j
public class GoogleSearchEngine extends AbstractSearchEngine<GoogleSetting> {
public GoogleSearchEngine() {
super(AdiConstant.SearchEngineName.GOOGLE, GOOGLE_SETTING, GoogleSetting.class);
}
@Override
public boolean isEnabled() {
return StringUtils.isNoneBlank(setting.getKey(), setting.getCx());
}
@Override
public SearchResult search(String searchTxt) {
SearchResult result = new SearchResult();
List<SearchResultItem> items = new ArrayList<>();
try {
ResponseEntity<GoogleSearchResp> resp = getRestTemplate().getForEntity(MessageFormat.format("{0}?key={1}&cx={2}&q={3}", setting.getUrl(), setting.getKey(), setting.getCx(), searchTxt), GoogleSearchResp.class);
if (null != resp && HttpStatus.OK.isSameCodeAs(resp.getStatusCode())) {
GoogleSearchResp googleSearchResp = resp.getBody();
if (null != googleSearchResp.getError()) {
log.error("google search error,code:{},message:{}", googleSearchResp.getError().getCode(), googleSearchResp.getError().getMessage());
result.setErrorMessage(googleSearchResp.getError().getMessage());
} else {
log.info("google response:{}", resp);
items = MPPageUtil.convertTo(googleSearchResp.getItems(), SearchResultItem.class);
}
}
} catch (Exception e) {
log.error("google search error", e);
}
result.setItems(items);
return result;
}
}

View File

@ -0,0 +1,40 @@
package com.moyz.adi.common.searchengine;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.interfaces.AbstractSearchEngine;
import com.moyz.adi.common.vo.SearchEngineInfo;
import lombok.extern.slf4j.Slf4j;
import java.util.LinkedHashMap;
import java.util.Map;
/**
* Search engine context. strategy design model
*/
@Slf4j
public class SearchEngineContext {
public static final Map<String, SearchEngineInfo> NAME_TO_ENGINE = new LinkedHashMap<>();
private AbstractSearchEngine iSearchEngine;
public SearchEngineContext(String searchEngineName) {
if (null == NAME_TO_ENGINE.get(searchEngineName)) {
log.warn("︿︿︿ Can not find {}, use the default engine GOOGLE ︿︿︿", searchEngineName);
iSearchEngine = NAME_TO_ENGINE.get(AdiConstant.SearchEngineName.GOOGLE).getEngine();
} else {
iSearchEngine = NAME_TO_ENGINE.get(searchEngineName).getEngine();
}
}
public static void addEngine(String engineName, AbstractSearchEngine searchEngine) {
SearchEngineInfo info = new SearchEngineInfo();
info.setName(engineName);
info.setEnable(searchEngine.isEnabled());
info.setEngine(searchEngine);
NAME_TO_ENGINE.put(engineName, info);
}
public AbstractSearchEngine getEngine() {
return iSearchEngine;
}
}

View File

@ -0,0 +1,82 @@
package com.moyz.adi.common.service;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.dto.AiSearchRecordResp;
import com.moyz.adi.common.dto.AiSearchResp;
import com.moyz.adi.common.dto.SearchEngineResp;
import com.moyz.adi.common.entity.AiSearchRecord;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.mapper.AiSearchRecordMapper;
import com.moyz.adi.common.util.BizPager;
import com.moyz.adi.common.util.MPPageUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND;
/**
* Ai search
*/
@Slf4j
@Service
public class AiSearchRecordService extends ServiceImpl<AiSearchRecordMapper, AiSearchRecord> {
/**
* List search records
*
* @param maxId Anchor id
* @param keyword user's question
* @return
*/
public AiSearchResp listByMaxId(Long maxId, String keyword) {
LambdaQueryWrapper<AiSearchRecord> wrapper = new LambdaQueryWrapper<>();
wrapper.eq(AiSearchRecord::getUserId, ThreadContext.getCurrentUserId());
wrapper.eq(AiSearchRecord::getIsDeleted, false);
if (StringUtils.isNotBlank(keyword)) {
wrapper.like(AiSearchRecord::getQuestion, keyword);
}
AiSearchResp result = new AiSearchResp();
BizPager.listByMaxId(maxId, wrapper, this, AiSearchRecord::getId, (recordList, minId) -> {
List<AiSearchRecordResp> list = MPPageUtil.convertTo(recordList, AiSearchRecordResp.class);
list.forEach(item -> {
if(null == item.getSearchEngineResp()){
SearchEngineResp searchEngineResp = new SearchEngineResp();
searchEngineResp.setItems(new ArrayList<>());
item.setSearchEngineResp(searchEngineResp);
}
});
result.setRecords(list);
result.setMinId(minId);
});
return result;
}
public boolean softDelete(String uuid) {
if (ThreadContext.getCurrentUser().getIsAdmin()) {
return ChainWrappers.lambdaUpdateChain(baseMapper)
.eq(AiSearchRecord::getUuid, uuid)
.set(AiSearchRecord::getIsDeleted, true)
.update();
}
AiSearchRecord exist = ChainWrappers.lambdaQueryChain(baseMapper)
.eq(AiSearchRecord::getUuid, uuid)
.eq(AiSearchRecord::getUserId, ThreadContext.getCurrentUserId())
.one();
if (null == exist) {
throw new BaseException(A_DATA_NOT_FOUND);
}
return ChainWrappers.lambdaUpdateChain(baseMapper)
.eq(AiSearchRecord::getId, exist.getId())
.set(AiSearchRecord::getIsDeleted, true)
.update();
}
}

View File

@ -8,7 +8,6 @@ import com.moyz.adi.common.dto.AskReq;
import com.moyz.adi.common.entity.Conversation; import com.moyz.adi.common.entity.Conversation;
import com.moyz.adi.common.entity.ConversationMessage; import com.moyz.adi.common.entity.ConversationMessage;
import com.moyz.adi.common.entity.User; import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.entity.UserDayCost;
import com.moyz.adi.common.enums.ChatMessageRoleEnum; import com.moyz.adi.common.enums.ChatMessageRoleEnum;
import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
@ -16,8 +15,6 @@ import com.moyz.adi.common.helper.QuotaHelper;
import com.moyz.adi.common.helper.SSEEmitterHelper; import com.moyz.adi.common.helper.SSEEmitterHelper;
import com.moyz.adi.common.mapper.ConversationMessageMapper; import com.moyz.adi.common.mapper.ConversationMessageMapper;
import com.moyz.adi.common.util.LocalCache; import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.util.LocalDateTimeUtil;
import com.moyz.adi.common.util.UserUtil;
import com.moyz.adi.common.vo.AnswerMeta; import com.moyz.adi.common.vo.AnswerMeta;
import com.moyz.adi.common.vo.PromptMeta; import com.moyz.adi.common.vo.PromptMeta;
import com.moyz.adi.common.vo.SseAskParams; import com.moyz.adi.common.vo.SseAskParams;
@ -66,11 +63,16 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
public SseEmitter sseAsk(AskReq askReq) { public SseEmitter sseAsk(AskReq askReq) {
SseEmitter sseEmitter = new SseEmitter(); SseEmitter sseEmitter = new SseEmitter();
User user = ThreadContext.getCurrentUser();
if (!sseEmitterHelper.checkOrComplete(user, sseEmitter)) {
return sseEmitter;
}
sseEmitterHelper.startSse(user, sseEmitter);
_this.asyncCheckAndPushToClient(sseEmitter, ThreadContext.getCurrentUser(), askReq); _this.asyncCheckAndPushToClient(sseEmitter, ThreadContext.getCurrentUser(), askReq);
return sseEmitter; return sseEmitter;
} }
private boolean check(SseEmitter sseEmitter, User user, AskReq askReq) { private boolean checkConversation(SseEmitter sseEmitter, User user, AskReq askReq) {
try { try {
//check 1: the conversation has been deleted //check 1: the conversation has been deleted
@ -79,7 +81,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
.eq(Conversation::getIsDeleted, true) .eq(Conversation::getIsDeleted, true)
.one(); .one();
if (null != delConv) { if (null != delConv) {
sseEmitterHelper.sendErrorMsg(sseEmitter, "该对话已经删除"); sseEmitterHelper.sendErrorAndComplete(user.getId(), sseEmitter, "该对话已经删除");
return false; return false;
} }
@ -90,14 +92,14 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
.count(); .count();
long convsMax = Integer.parseInt(LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.CONVERSATION_MAX_NUM)); long convsMax = Integer.parseInt(LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.CONVERSATION_MAX_NUM));
if (convsCount >= convsMax) { if (convsCount >= convsMax) {
sseEmitterHelper.sendErrorMsg(sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax); sseEmitterHelper.sendErrorAndComplete(user.getId(), sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax);
return false; return false;
} }
//check 3: current user's quota //check 3: current user's quota
ErrorEnum errorMsg = quotaHelper.checkTextQuota(user); ErrorEnum errorMsg = quotaHelper.checkTextQuota(user);
if (null != errorMsg) { if (null != errorMsg) {
sseEmitterHelper.sendErrorMsg(sseEmitter, errorMsg.getInfo()); sseEmitterHelper.sendErrorAndComplete(user.getId(), sseEmitter, errorMsg.getInfo());
return false; return false;
} }
} catch (Exception e) { } catch (Exception e) {
@ -112,7 +114,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
public void asyncCheckAndPushToClient(SseEmitter sseEmitter, User user, AskReq askReq) { public void asyncCheckAndPushToClient(SseEmitter sseEmitter, User user, AskReq askReq) {
log.info("asyncCheckAndPushToClient,userId:{}", user.getId()); log.info("asyncCheckAndPushToClient,userId:{}", user.getId());
//check business rules //check business rules
if (!check(sseEmitter, user, askReq)) { if (!checkConversation(sseEmitter, user, askReq)) {
return; return;
} }
@ -161,7 +163,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
} }
} }
sseEmitterHelper.process(user, sseAskParams, (response, questionMeta, answerMeta) -> { sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, questionMeta, answerMeta) -> {
_this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta); _this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta);
}); });
} }

View File

@ -22,7 +22,7 @@ import static com.moyz.adi.common.enums.ErrorEnum.B_LLM_SECRET_KEY_NOT_SET;
public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> { public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> {
public DashScopeLLMService(String modelName) { public DashScopeLLMService(String modelName) {
super(modelName, AdiConstant.SysConfigKey.DASHSCOPE_SETTING, DashScopeSetting.class, null); super(modelName, AdiConstant.SysConfigKey.DASHSCOPE_SETTING, DashScopeSetting.class);
} }
@Override @Override

View File

@ -3,6 +3,8 @@ package com.moyz.adi.common.service;
import com.moyz.adi.common.cosntant.AdiConstant; import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.helper.ImageModelContext; import com.moyz.adi.common.helper.ImageModelContext;
import com.moyz.adi.common.helper.LLMContext; import com.moyz.adi.common.helper.LLMContext;
import com.moyz.adi.common.searchengine.GoogleSearchEngine;
import com.moyz.adi.common.searchengine.SearchEngineContext;
import dev.langchain4j.model.openai.OpenAiModelName; import dev.langchain4j.model.openai.OpenAiModelName;
import jakarta.annotation.PostConstruct; import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
@ -43,16 +45,16 @@ public class Initializer {
//openai //openai
String[] openaiModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OPENAI_SETTING); String[] openaiModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OPENAI_SETTING);
if(openaiModels.length == 0){ if (openaiModels.length == 0) {
log.warn("openai service is disabled"); log.warn("openai service is disabled");
} }
for (String model : openaiModels) { for (String model : openaiModels) {
LLMContext.addLLMService(model, new OpenAiLLMService(model, proxy)); LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy));
} }
//dashscope //dashscope
String[] dashscopeModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.DASHSCOPE_SETTING); String[] dashscopeModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.DASHSCOPE_SETTING);
if(dashscopeModels.length == 0){ if (dashscopeModels.length == 0) {
log.warn("dashscope service is disabled"); log.warn("dashscope service is disabled");
} }
for (String model : dashscopeModels) { for (String model : dashscopeModels) {
@ -61,7 +63,7 @@ public class Initializer {
//qianfan //qianfan
String[] qianfanModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.QIANFAN_SETTING); String[] qianfanModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.QIANFAN_SETTING);
if(qianfanModels.length == 0){ if (qianfanModels.length == 0) {
log.warn("qianfan service is disabled"); log.warn("qianfan service is disabled");
} }
for (String model : qianfanModels) { for (String model : qianfanModels) {
@ -70,14 +72,18 @@ public class Initializer {
//ollama //ollama
String[] ollamaModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OLLAMA_SETTING); String[] ollamaModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OLLAMA_SETTING);
if(ollamaModels.length == 0){ if (ollamaModels.length == 0) {
log.warn("ollama service is disabled"); log.warn("ollama service is disabled");
} }
for (String model : ollamaModels) { for (String model : ollamaModels) {
LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model)); LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model));
} }
ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2, proxy)); ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2).setProxy(proxy));
//search engine
SearchEngineContext.addEngine(AdiConstant.SearchEngineName.GOOGLE, new GoogleSearchEngine().setProxy(proxy));
ragService.init(); ragService.init();

View File

@ -5,6 +5,7 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers; import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.moyz.adi.common.base.ThreadContext; import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.dto.KbItemEditReq; import com.moyz.adi.common.dto.KbItemEditReq;
import com.moyz.adi.common.entity.KnowledgeBase; import com.moyz.adi.common.entity.KnowledgeBase;
import com.moyz.adi.common.entity.KnowledgeBaseItem; import com.moyz.adi.common.entity.KnowledgeBaseItem;
@ -24,6 +25,7 @@ import org.springframework.transaction.annotation.Transactional;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import static com.moyz.adi.common.cosntant.AdiConstant.RAG_TYPE_KB;
import static com.moyz.adi.common.enums.ErrorEnum.*; import static com.moyz.adi.common.enums.ErrorEnum.*;
@Slf4j @Slf4j
@ -112,8 +114,8 @@ public class KnowledgeBaseItemService extends ServiceImpl<KnowledgeBaseItemMappe
knowledgeBaseEmbeddingService.deleteByItemUuid(kbItem.getUuid()); knowledgeBaseEmbeddingService.deleteByItemUuid(kbItem.getUuid());
Metadata metadata = new Metadata(); Metadata metadata = new Metadata();
metadata.add("kb_uuid", kbItem.getKbUuid()); metadata.add(AdiConstant.EmbeddingMetadataKey.KB_UUID, kbItem.getKbUuid());
metadata.add("kb_item_uuid", kbItem.getUuid()); metadata.add(AdiConstant.EmbeddingMetadataKey.KB_ITEM_UUID, kbItem.getUuid());
Document document = new Document(kbItem.getRemark(), metadata); Document document = new Document(kbItem.getRemark(), metadata);
ragService.ingest(document); ragService.ingest(document);

View File

@ -66,7 +66,7 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
return baseMapper.selectOne(wrapper); return baseMapper.selectOne(wrapper);
} }
public boolean softDelele(String uuid) { public boolean softDelete(String uuid) {
if (ThreadContext.getCurrentUser().getIsAdmin()) { if (ThreadContext.getCurrentUser().getIsAdmin()) {
return ChainWrappers.lambdaUpdateChain(baseMapper) return ChainWrappers.lambdaUpdateChain(baseMapper)
.eq(KnowledgeBaseQaRecord::getUuid, uuid) .eq(KnowledgeBaseQaRecord::getUuid, uuid)

View File

@ -4,7 +4,9 @@ import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers; import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.google.common.collect.ImmutableMap;
import com.moyz.adi.common.base.ThreadContext; import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.cosntant.RedisKeyConstant; import com.moyz.adi.common.cosntant.RedisKeyConstant;
import com.moyz.adi.common.dto.KbEditReq; import com.moyz.adi.common.dto.KbEditReq;
import com.moyz.adi.common.dto.QAReq; import com.moyz.adi.common.dto.QAReq;
@ -39,6 +41,7 @@ import java.time.LocalDateTime;
import java.util.*; import java.util.*;
import static com.moyz.adi.common.cosntant.AdiConstant.POI_DOC_TYPES; import static com.moyz.adi.common.cosntant.AdiConstant.POI_DOC_TYPES;
import static com.moyz.adi.common.cosntant.AdiConstant.SysConfigKey.QUOTA_BY_QA_ASK_DAILY;
import static com.moyz.adi.common.enums.ErrorEnum.*; import static com.moyz.adi.common.enums.ErrorEnum.*;
import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.loadDocument; import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.loadDocument;
@ -161,9 +164,8 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
//向量化 //向量化
Document docWithoutPath = new Document(document.text()); Document docWithoutPath = new Document(document.text());
docWithoutPath.metadata() docWithoutPath.metadata()
.add("kb_uuid", knowledgeBase.getUuid()) .add(AdiConstant.EmbeddingMetadataKey.KB_UUID, knowledgeBase.getUuid())
.add("kb_item_uuid", knowledgeBaseItem.getUuid()); .add(AdiConstant.EmbeddingMetadataKey.KB_ITEM_UUID, knowledgeBaseItem.getUuid());
ragService.ingest(docWithoutPath); ragService.ingest(docWithoutPath);
knowledgeBaseItemService knowledgeBaseItemService
@ -205,7 +207,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
LambdaQueryWrapper<KnowledgeBase> wrapper = new LambdaQueryWrapper(); LambdaQueryWrapper<KnowledgeBase> wrapper = new LambdaQueryWrapper();
wrapper.eq(KnowledgeBase::getIsPublic, true); wrapper.eq(KnowledgeBase::getIsPublic, true);
wrapper.eq(KnowledgeBase::getIsDeleted, false); wrapper.eq(KnowledgeBase::getIsDeleted, false);
if(StringUtils.isNotBlank(keyword)){ if (StringUtils.isNotBlank(keyword)) {
wrapper.like(KnowledgeBase::getTitle, keyword); wrapper.like(KnowledgeBase::getTitle, keyword);
} }
wrapper.orderByDesc(KnowledgeBase::getStarCount, KnowledgeBase::getUpdateTime); wrapper.orderByDesc(KnowledgeBase::getStarCount, KnowledgeBase::getUpdateTime);
@ -223,7 +225,8 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
public KnowledgeBaseQaRecord ask(String kbUuid, String question, String modelName) { public KnowledgeBaseQaRecord ask(String kbUuid, String question, String modelName) {
checkRequestTimesOrThrow(); checkRequestTimesOrThrow();
KnowledgeBase knowledgeBase = getOrThrow(kbUuid); KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
Pair<String, Response<AiMessage>> responsePair = ragService.retrieveAndAsk(kbUuid, question, modelName); Map<String, String> metadataCond = ImmutableMap.of(AdiConstant.EmbeddingMetadataKey.KB_UUID, kbUuid);
Pair<String, Response<AiMessage>> responsePair = ragService.retrieveAndAsk(metadataCond, question, modelName);
Response<AiMessage> ar = responsePair.getRight(); Response<AiMessage> ar = responsePair.getRight();
int inputTokenCount = ar.tokenUsage().inputTokenCount(); int inputTokenCount = ar.tokenUsage().inputTokenCount();
@ -235,7 +238,12 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
public SseEmitter sseAsk(String kbUuid, QAReq req) { public SseEmitter sseAsk(String kbUuid, QAReq req) {
checkRequestTimesOrThrow(); checkRequestTimesOrThrow();
SseEmitter sseEmitter = new SseEmitter(); SseEmitter sseEmitter = new SseEmitter();
_this.retrieveAndPushToLLM(ThreadContext.getCurrentUser(), sseEmitter, kbUuid, req); User user = ThreadContext.getCurrentUser();
if (!sseEmitterHelper.checkOrComplete(user, sseEmitter)) {
return sseEmitter;
}
sseEmitterHelper.startSse(user, sseEmitter);
_this.retrieveAndPushToLLM(user, sseEmitter, kbUuid, req);
return sseEmitter; return sseEmitter;
} }
@ -253,7 +261,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
private void checkRequestTimesOrThrow() { private void checkRequestTimesOrThrow() {
String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd")); String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd"));
String askTimes = stringRedisTemplate.opsForValue().get(key); String askTimes = stringRedisTemplate.opsForValue().get(key);
String askQuota = SysConfigService.getByKey("quota_by_qa_ask_daily"); String askQuota = SysConfigService.getByKey(QUOTA_BY_QA_ASK_DAILY);
if (null != askTimes && null != askQuota && Integer.parseInt(askTimes) >= Integer.parseInt(askQuota)) { if (null != askTimes && null != askQuota && Integer.parseInt(askTimes) >= Integer.parseInt(askQuota)) {
throw new BaseException(A_QA_ASK_LIMIT); throw new BaseException(A_QA_ASK_LIMIT);
} }
@ -265,10 +273,10 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
public void retrieveAndPushToLLM(User user, SseEmitter sseEmitter, String kbUuid, QAReq req) { public void retrieveAndPushToLLM(User user, SseEmitter sseEmitter, String kbUuid, QAReq req) {
log.info("retrieveAndPushToLLM,kbUuid:{},userId:{}", kbUuid, user.getId()); log.info("retrieveAndPushToLLM,kbUuid:{},userId:{}", kbUuid, user.getId());
KnowledgeBase knowledgeBase = getOrThrow(kbUuid); KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
Map<String, String> metadataCond = ImmutableMap.of(AdiConstant.EmbeddingMetadataKey.KB_UUID, kbUuid);
Prompt prompt = ragService.retrieveAndCreatePrompt(kbUuid, req.getQuestion()); Prompt prompt = ragService.retrieveAndCreatePrompt(metadataCond, req.getQuestion());
if(null == prompt){ if (null == prompt) {
sseEmitterHelper.sendAndComplete(sseEmitter, B_KNOWLEDGE_BASE_NO_ANSWER.getInfo()); sseEmitterHelper.sendAndComplete(user.getId(), sseEmitter, B_NO_ANSWER.getInfo());
return; return;
} }
String promptText = prompt.text(); String promptText = prompt.text();
@ -277,7 +285,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
sseAskParams.setSseEmitter(sseEmitter); sseAskParams.setSseEmitter(sseEmitter);
sseAskParams.setUserMessage(promptText); sseAskParams.setUserMessage(promptText);
sseAskParams.setModelName(req.getModelName()); sseAskParams.setModelName(req.getModelName());
sseEmitterHelper.process(user, sseAskParams, (response, promptMeta, answerMeta) -> { sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, promptMeta, answerMeta) -> {
knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), promptText, promptMeta.getTokens(), response, answerMeta.getTokens()); knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), promptText, promptMeta.getTokens(), response, answerMeta.getTokens());
userDayCostService.appendCostToUser(user, promptMeta.getTokens() + answerMeta.getTokens()); userDayCostService.appendCostToUser(user, promptMeta.getTokens() + answerMeta.getTokens());
}); });

View File

@ -7,14 +7,13 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.ollama.OllamaChatModel; import dev.langchain4j.model.ollama.OllamaChatModel;
import dev.langchain4j.model.ollama.OllamaStreamingChatModel; import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
import static com.moyz.adi.common.cosntant.AdiConstant.SysConfigKey.OLLAMA_SETTING; import static com.moyz.adi.common.cosntant.AdiConstant.SysConfigKey.OLLAMA_SETTING;
public class OllamaLLMService extends AbstractLLMService<OllamaSetting> { public class OllamaLLMService extends AbstractLLMService<OllamaSetting> {
public OllamaLLMService(String modelName) { public OllamaLLMService(String modelName) {
super(modelName, OLLAMA_SETTING, OllamaSetting.class, null); super(modelName, OLLAMA_SETTING, OllamaSetting.class);
} }
@Override @Override

View File

@ -45,8 +45,8 @@ public class OpenAiImageModelService extends AbstractImageModelService<OpenAiSet
@Resource @Resource
private ObjectMapper objectMapper; private ObjectMapper objectMapper;
public OpenAiImageModelService(String modelName, Proxy proxy) { public OpenAiImageModelService(String modelName) {
super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class, proxy); super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class);
} }
@Override @Override

View File

@ -17,7 +17,6 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings; import org.apache.logging.log4j.util.Strings;
import java.net.Proxy;
import java.time.Duration; import java.time.Duration;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
@ -28,8 +27,8 @@ import java.time.temporal.ChronoUnit;
@Accessors(chain = true) @Accessors(chain = true)
public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> { public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
public OpenAiLLMService(String modelName, Proxy proxy) { public OpenAiLLMService(String modelName) {
super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class, proxy); super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class);
} }
@Override @Override

View File

@ -19,7 +19,7 @@ import org.apache.commons.lang3.StringUtils;
public class QianFanLLMService extends AbstractLLMService<QianFanSetting> { public class QianFanLLMService extends AbstractLLMService<QianFanSetting> {
public QianFanLLMService(String modelName) { public QianFanLLMService(String modelName) {
super(modelName, AdiConstant.SysConfigKey.QIANFAN_SETTING, QianFanSetting.class, null); super(modelName, AdiConstant.SysConfigKey.QIANFAN_SETTING, QianFanSetting.class);
} }
@Override @Override

View File

@ -1,10 +1,7 @@
package com.moyz.adi.common.service; package com.moyz.adi.common.service;
import com.moyz.adi.common.helper.LLMContext; import com.moyz.adi.common.helper.LLMContext;
import com.moyz.adi.common.interfaces.TriConsumer;
import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore; import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore;
import com.moyz.adi.common.vo.AnswerMeta;
import com.moyz.adi.common.vo.PromptMeta;
import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter; import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.splitter.DocumentSplitters; import dev.langchain4j.data.document.splitter.DocumentSplitters;
@ -14,7 +11,6 @@ import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel; import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.openai.OpenAiTokenizer; import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingMatch;
@ -24,35 +20,37 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.Triple;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import static com.moyz.adi.common.cosntant.AdiConstant.PROMPT_TEMPLATE;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.joining;
@Slf4j @Slf4j
@Service
public class RAGService { public class RAGService {
@Value("${spring.datasource.url}")
private String dataBaseUrl; private String dataBaseUrl;
@Value("${spring.datasource.username}")
private String dataBaseUserName; private String dataBaseUserName;
@Value("${spring.datasource.password}")
private String dataBasePassword; private String dataBasePassword;
private static final PromptTemplate promptTemplate = PromptTemplate.from("尽可能准确地回答下面的问题: {{question}}\n\n根据以下知识库的内容:\n{{information}}");
private String tableName;
private EmbeddingModel embeddingModel; private EmbeddingModel embeddingModel;
private EmbeddingStore<TextSegment> embeddingStore; private EmbeddingStore<TextSegment> embeddingStore;
public RAGService(String tableName, String dataBaseUrl, String dataBaseUserName, String dataBasePassword) {
this.tableName = tableName;
this.dataBasePassword = dataBasePassword;
this.dataBaseUserName = dataBaseUserName;
this.dataBaseUrl = dataBaseUrl;
}
public void init() { public void init() {
log.info("initEmbeddingModel"); log.info("initEmbeddingModel");
embeddingModel = new AllMiniLmL6V2EmbeddingModel(); embeddingModel = new AllMiniLmL6V2EmbeddingModel();
@ -88,7 +86,7 @@ public class RAGService {
.dimension(384) .dimension(384)
.createTable(true) .createTable(true)
.dropTableFirst(false) .dropTableFirst(false)
.table("adi_knowledge_base_embedding") .table(tableName)
.build(); .build();
return embeddingStore; return embeddingStore;
} }
@ -112,7 +110,14 @@ public class RAGService {
getEmbeddingStoreIngestor().ingest(document); getEmbeddingStoreIngestor().ingest(document);
} }
public Prompt retrieveAndCreatePrompt(String kbUuid, String question) { /**
* Retrieve documents and create prompt
*
* @param metadataCond Query condition
* @param question User's question
* @return Document in the vector db
*/
public Prompt retrieveAndCreatePrompt(Map<String, String> metadataCond, String question) {
// Embed the question // Embed the question
Embedding questionEmbedding = embeddingModel.embed(question).content(); Embedding questionEmbedding = embeddingModel.embed(question).content();
@ -120,7 +125,7 @@ public class RAGService {
// You can play with parameters below to find a sweet spot for your specific use case // You can play with parameters below to find a sweet spot for your specific use case
int maxResults = 3; int maxResults = 3;
double minScore = 0.6; double minScore = 0.6;
List<EmbeddingMatch<TextSegment>> relevantEmbeddings = ((AdiPgVectorEmbeddingStore) embeddingStore).findRelevantByKbUuid(kbUuid, questionEmbedding, maxResults, minScore); List<EmbeddingMatch<TextSegment>> relevantEmbeddings = ((AdiPgVectorEmbeddingStore) embeddingStore).findRelevantByMetadata(metadataCond, questionEmbedding, maxResults, minScore);
// Create a prompt for the model that includes question and relevant embeddings // Create a prompt for the model that includes question and relevant embeddings
String information = relevantEmbeddings.stream() String information = relevantEmbeddings.stream()
@ -130,23 +135,28 @@ public class RAGService {
if (StringUtils.isBlank(information)) { if (StringUtils.isBlank(information)) {
return null; return null;
} }
return promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information))); return PROMPT_TEMPLATE.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information)));
} }
/** /**
* 召回并提问 * 召回并提问
* *
* @param kbUuid 知识库uuid * @param metadataCond metadata condition
* @param question 用户的问题 * @param question user's question
* @param modelName LLM model name * @param modelName LLM model name
* @return * @return
*/ */
public Pair<String, Response<AiMessage>> retrieveAndAsk(String kbUuid, String question, String modelName) { public Pair<String, Response<AiMessage>> retrieveAndAsk(Map<String, String> metadataCond, String question, String modelName) {
Prompt prompt = retrieveAndCreatePrompt(kbUuid, question);
Prompt prompt = retrieveAndCreatePrompt(metadataCond, question);
if (null == prompt) { if (null == prompt) {
return null; return null;
} }
Response<AiMessage> response = new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage()); Response<AiMessage> response = new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage());
return new ImmutablePair<>(prompt.text(), response); return new ImmutablePair<>(prompt.text(), response);
} }
public static final String parsePromptTemplate(String question, String information) {
return PROMPT_TEMPLATE.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information))).text();
}
} }

View File

@ -0,0 +1,250 @@
package com.moyz.adi.common.service;
import com.google.common.collect.ImmutableMap;
import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.dto.SearchEngineResp;
import com.moyz.adi.common.dto.SearchResult;
import com.moyz.adi.common.dto.SearchResultItem;
import com.moyz.adi.common.entity.AiSearchRecord;
import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.helper.SSEEmitterHelper;
import com.moyz.adi.common.searchengine.SearchEngineContext;
import com.moyz.adi.common.vo.SseAskParams;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.model.input.Prompt;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.jsoup.Jsoup;
import org.jsoup.safety.Cleaner;
import org.jsoup.safety.Safelist;
import org.springframework.context.annotation.Lazy;
import org.springframework.core.task.AsyncTaskExecutor;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import static com.moyz.adi.common.enums.ErrorEnum.B_NO_ANSWER;
/**
* RAG search
*/
@Slf4j
@Service
public class SearchService {
@Lazy
@Resource
private SearchService _this;
@Resource
private RAGService searchRagService;
@Resource
private SSEEmitterHelper sseEmitterHelper;
@Resource
private AiSearchRecordService aiSearchRecordService;
@Resource
private AsyncTaskExecutor mainExecutor;
public SseEmitter search(boolean isBriefSearch, String searchText, String engineName, String modelName) {
User user = ThreadContext.getCurrentUser();
SseEmitter sseEmitter = new SseEmitter();
if (!sseEmitterHelper.checkOrComplete(user, sseEmitter)) {
return sseEmitter;
}
sseEmitterHelper.startSse(user, sseEmitter);
_this.asyncSearch(user, sseEmitter, isBriefSearch, searchText, engineName, modelName);
return sseEmitter;
}
@Async
public void asyncSearch(User user, SseEmitter sseEmitter, boolean isBriefSearch, String searchText, String engineName, String modelName) {
SearchResult searchResult = new SearchEngineContext(engineName).getEngine().search(searchText);
if (StringUtils.isNotBlank(searchResult.getErrorMessage())) {
sseEmitterHelper.sendAndComplete(user.getId(), sseEmitter, searchResult.getErrorMessage());
return;
}
if (CollectionUtils.isEmpty(searchResult.getItems())) {
sseEmitterHelper.sendAndComplete(user.getId(), sseEmitter, B_NO_ANSWER.getInfo());
return;
}
boolean sendFail = false;
try {
sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.AI_SEARCH_SOURCE_LINKS).data(searchResult.getItems()));
} catch (IOException e) {
sendFail = true;
log.error("asyncSearch error", e);
sseEmitterHelper.sendErrorAndComplete(user.getId(), sseEmitter, e.getMessage());
}
if (sendFail) {
return;
}
if (isBriefSearch) {
briefSearch(user, searchText, modelName, searchResult.getItems(), sseEmitter);
} else {
detailSearch(user, searchText, engineName, modelName, searchResult.getItems(), sseEmitter);
}
}
/**
* 1.Search by search engine
* 2.Create prompt by search response
* 3.Send prompt to llm
*
* @param user
* @param searchText
* @param modelName
* @param resultItems
* @param sseEmitter
*/
public void briefSearch(User user, String searchText, String modelName, List<SearchResultItem> resultItems, SseEmitter sseEmitter) {
log.info("briefSearch,searchText:{}", searchText);
StringBuilder builder = new StringBuilder();
for (SearchResultItem item : resultItems) {
builder.append(item.getSnippet()).append("\n\n");
}
String ragQuestion = builder.toString();
String prompt = RAGService.parsePromptTemplate(searchText, ragQuestion);
SearchEngineResp resp = new SearchEngineResp().setItems(resultItems);
SseAskParams sseAskParams = new SseAskParams();
sseAskParams.setSystemMessage(StringUtils.EMPTY);
sseAskParams.setSseEmitter(sseEmitter);
sseAskParams.setUserMessage(prompt);
sseAskParams.setModelName(modelName);
sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, promptMeta, answerMeta) -> {
AiSearchRecord newRecord = new AiSearchRecord();
newRecord.setUuid(UUID.randomUUID().toString().replace("-", ""));
newRecord.setQuestion(searchText);
newRecord.setSearchEngineResp(resp);
newRecord.setPrompt(prompt);
newRecord.setPromptTokens(promptMeta.getTokens());
newRecord.setAnswer(response);
newRecord.setAnswerTokens(answerMeta.getTokens());
newRecord.setUserUuid(user.getUuid());
newRecord.setUserId(user.getId());
aiSearchRecordService.save(newRecord);
});
}
/**
* 1.Search by search engine
* 2.Save the response to pgvector
* 3.Retrieve document and create prompt
* 4.Send prompt to llm
*
* @param user
* @param searchText
* @param engineName
* @param modelName
* @param resultItems
* @param sseEmitter
*/
public void detailSearch(User user, String searchText, String engineName, String modelName, List<SearchResultItem> resultItems, SseEmitter sseEmitter) {
log.info("detailSearch,searchText:{}", searchText);
//Save to DB
SearchEngineResp resp = new SearchEngineResp().setItems(resultItems);
AiSearchRecord newRecord = new AiSearchRecord();
String searchUuid = UUID.randomUUID().toString().replace("-", "");
newRecord.setUuid(searchUuid);
newRecord.setQuestion(searchText);
newRecord.setSearchEngineResp(resp);
newRecord.setUserId(user.getId());
newRecord.setUserUuid(user.getUuid());
aiSearchRecordService.save(newRecord);
CountDownLatch countDownLatch = new CountDownLatch(resultItems.size());
for (int i = 0; i < resultItems.size(); i++) {
int finalI = i;
mainExecutor.execute(() -> {
try {
SearchResultItem item = resultItems.get(finalI);
String content;
if (finalI < 2) {
content = getContentFromRemote(item);
//Fill content with html body text
item.setContent(content);
} else {
content = item.getSnippet();
}
//embedding
Metadata metadata = new Metadata();
metadata.add(AdiConstant.EmbeddingMetadataKey.ENGINE_NAME, engineName);
metadata.add(AdiConstant.EmbeddingMetadataKey.SEARCH_UUID, searchUuid);
Document document = new Document(content, metadata);
searchRagService.ingest(document);
} catch (Exception e) {
log.error("Detail search error,uuid:{}", searchUuid, e);
} finally {
countDownLatch.countDown();
}
});
}
try {
countDownLatch.await();
} catch (InterruptedException e) {
log.error("CountDownLatch await error,uuid:{}", searchUuid, e);
throw new RuntimeException(e);
}
log.info("Create prompt");
Prompt prompt = searchRagService.retrieveAndCreatePrompt(ImmutableMap.of(AdiConstant.EmbeddingMetadataKey.SEARCH_UUID, searchUuid), searchText);
SseAskParams sseAskParams = new SseAskParams();
sseAskParams.setSystemMessage(StringUtils.EMPTY);
sseAskParams.setSseEmitter(sseEmitter);
sseAskParams.setUserMessage(prompt.text());
sseAskParams.setModelName(modelName);
log.info("Push to model");
sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, promptMeta, answerMeta) -> {
AiSearchRecord existRecord = aiSearchRecordService.lambdaQuery().eq(AiSearchRecord::getUuid, searchUuid).one();
AiSearchRecord updateRecord = new AiSearchRecord();
updateRecord.setId(existRecord.getId());
//Update search engine response content.(with html body text)
updateRecord.setSearchEngineResp(new SearchEngineResp().setItems(resultItems));
updateRecord.setPrompt(prompt.text());
updateRecord.setPromptTokens(promptMeta.getTokens());
updateRecord.setAnswer(response);
updateRecord.setAnswerTokens(answerMeta.getTokens());
aiSearchRecordService.updateById(updateRecord);
});
}
private String getContentFromRemote(SearchResultItem item) {
String result = "";
try {
org.jsoup.nodes.Document doc = Jsoup.connect(item.getLink()).ignoreContentType(true).get();
if (doc.getElementsByTag("main").size() > 0) {
result = doc.getElementsByTag("main").get(0).html();
} else {
result = doc.body().html();
}
if (StringUtils.isBlank(result)) {
log.error("Empty content from {}, use snippet instead", item.getLink());
return item.getSnippet();
}
} catch (Exception e) {
log.error("Failed to load document from {}, use snippet instead", item.getLink(), e);
}
Cleaner cleaner = new Cleaner(Safelist.none());
return cleaner.clean(Jsoup.parse(result)).text();
}
}

View File

@ -9,6 +9,7 @@ import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
import lombok.Builder; import lombok.Builder;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils; import org.apache.commons.lang3.math.NumberUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -286,16 +287,24 @@ public class AdiPgVectorEmbeddingStore implements EmbeddingStore<TextSegment> {
} }
} }
//adi //adi
public List<EmbeddingMatch<TextSegment>> findRelevantByKbUuid(String kbUuid, Embedding referenceEmbedding, int maxResults, double minScore) { public List<EmbeddingMatch<TextSegment>> findRelevantByMetadata(Map<String, String> metadatCondition, Embedding referenceEmbedding, int maxResults, double minScore) {
List<EmbeddingMatch<TextSegment>> result = new ArrayList<>(); List<EmbeddingMatch<TextSegment>> result = new ArrayList<>();
try (Connection connection = setupConnection()) { try (Connection connection = setupConnection()) {
String referenceVector = Arrays.toString(referenceEmbedding.vector()); String referenceVector = Arrays.toString(referenceEmbedding.vector());
//新增查询条件kb_id //deal with metadata condition
StringBuilder whereSql = new StringBuilder();
if (null != metadatCondition && !metadatCondition.isEmpty()) {
whereSql = new StringBuilder("where");
for (String key : metadatCondition.keySet()) {
whereSql.append(" metadata->>'").append(key).append("' = '").append(metadatCondition.get(key)).append("' and");
}
whereSql.replace(whereSql.length() - 3, whereSql.length(), "");
}
String query = String.format( String query = String.format(
"WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s where metadata->>'kb_uuid' = '%s') SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;", "WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s " + whereSql + ") SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;",
referenceVector, table, kbUuid, minScore, maxResults); referenceVector, table, minScore, maxResults);
log.info(query);
PreparedStatement selectStmt = connection.prepareStatement(query); PreparedStatement selectStmt = connection.prepareStatement(query);
ResultSet resultSet = selectStmt.executeQuery(); ResultSet resultSet = selectStmt.executeQuery();

View File

@ -10,6 +10,7 @@ import org.apache.commons.collections4.CollectionUtils;
import java.util.List; import java.util.List;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.ObjLongConsumer;
import java.util.function.Supplier; import java.util.function.Supplier;
public class BizPager { public class BizPager {
@ -68,6 +69,21 @@ public class BizPager {
} while (!Thread.currentThread().isInterrupted() && records.size() == AdiConstant.DEFAULT_PAGE_SIZE); } while (!Thread.currentThread().isInterrupted() && records.size() == AdiConstant.DEFAULT_PAGE_SIZE);
} }
public static <T extends BaseEntity> void listByMaxId(Long maxId, LambdaQueryWrapper<T> queryWrapper, IService<T> service, SFunction<T, Long> idSupplier, ObjLongConsumer<List<T>> consumer) {
if (maxId > 0) {
queryWrapper.lt(idSupplier, maxId);
}
queryWrapper.orderByDesc(idSupplier);
queryWrapper.last("limit " + AdiConstant.DEFAULT_PAGE_SIZE);
long minId = 0;
List<T> records = service.list(queryWrapper);
if (CollectionUtils.isNotEmpty(records)) {
minId = records.stream().map(idSupplier).reduce(Long::min).get();
}
consumer.accept(records, minId);
}
/** /**
* 以Long类型的惟一字段通常为id为锚点按页获取数据 * 以Long类型的惟一字段通常为id为锚点按页获取数据
* <br/>不依赖mybatis-plus * <br/>不依赖mybatis-plus

View File

@ -0,0 +1,25 @@
package com.moyz.adi.common.util;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
@Component
public class SpringUtil implements ApplicationContextAware {
private static ApplicationContext applicationContext;
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
SpringUtil.applicationContext = applicationContext;
}
public static <T> T getBean(String name) {
return (T) applicationContext.getBean(name);
}
public static <T> T getBean(Class<T> clazz) {
return applicationContext.getBean(clazz);
}
}

View File

@ -0,0 +1,10 @@
package com.moyz.adi.common.vo;
import lombok.Data;
@Data
public class GoogleSetting {
private String url;
private String key;
private String cx;
}

View File

@ -0,0 +1,13 @@
package com.moyz.adi.common.vo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.moyz.adi.common.interfaces.AbstractSearchEngine;
import lombok.Data;
@Data
public class SearchEngineInfo {
private String name;
private Boolean enable;
@JsonIgnore
private AbstractSearchEngine engine;
}

View File

@ -390,6 +390,9 @@ VALUES ('qianfan_setting', '{"api_key":"","secret_key":"","models":[]}');
INSERT INTO adi_sys_config (name, value) INSERT INTO adi_sys_config (name, value)
VALUES ('ollama_setting', '{"base_url":"","models":[]}'); VALUES ('ollama_setting', '{"base_url":"","models":[]}');
INSERT INTO adi_sys_config (name, value) INSERT INTO adi_sys_config (name, value)
VALUES ('google_setting',
'{"url":"https://www.googleapis.com/customsearch/v1","key":"","cx":""}');
INSERT INTO adi_sys_config (name, value)
VALUES ('request_text_rate_limit', '{"times":24,"minutes":3}'); VALUES ('request_text_rate_limit', '{"times":24,"minutes":3}');
INSERT INTO adi_sys_config (name, value) INSERT INTO adi_sys_config (name, value)
VALUES ('request_image_rate_limit', '{"times":6,"minutes":3}'); VALUES ('request_image_rate_limit', '{"times":6,"minutes":3}');
@ -544,3 +547,48 @@ create trigger trigger_kb_qa_record_update_time
on adi_knowledge_base_qa_record on adi_knowledge_base_qa_record
for each row for each row
execute procedure update_modified_column(); execute procedure update_modified_column();
-- ai search
create table adi_ai_search_record
(
id bigserial primary key,
uuid varchar(32) default ''::character varying not null,
question varchar(1000) default ''::character varying not null,
search_engine_response jsonb not null,
prompt text default ''::character varying not null,
prompt_tokens integer DEFAULT 0 NOT NULL,
answer text default ''::character varying not null,
answer_tokens integer DEFAULT 0 NOT NULL,
user_id bigint default '0' NOT NULL,
user_uuid varchar(32) default ''::character varying not null,
create_time timestamp default CURRENT_TIMESTAMP not null,
update_time timestamp default CURRENT_TIMESTAMP not null,
is_deleted boolean default false not null
);
comment on table adi_ai_search_record is 'Search record';
comment on column adi_ai_search_record.question is 'User original question';
comment on column adi_ai_search_record.search_engine_response is 'Search engine''s response content';
comment on column adi_ai_search_record.prompt is 'Prompt of LLM';
comment on column adi_ai_search_record.prompt_tokens is 'prompt消耗的token数量';
comment on column adi_ai_search_record.answer is 'LLM response';
comment on column adi_ai_search_record.answer_tokens is 'LLM响应消耗的token数量';
comment on column adi_ai_search_record.user_id is 'Id from adi_user';
comment on column adi_ai_search_record.create_time is '创建时间';
comment on column adi_ai_search_record.update_time is '更新时间';
comment on column adi_ai_search_record.is_deleted is '0: Normal; 1: Deleted';
create trigger trigger_ai_search_record
before update
on adi_ai_search_record
for each row
execute procedure update_modified_column();