add ai search
This commit is contained in:
parent
abd9ced7cc
commit
86a70e09ca
36
README.md
36
README.md
|
@ -17,7 +17,9 @@
|
||||||
* 提示词
|
* 提示词
|
||||||
* 额度控制
|
* 额度控制
|
||||||
* 基于大模型的知识库(RAG)
|
* 基于大模型的知识库(RAG)
|
||||||
|
* 基于大模型的搜索(RAG)
|
||||||
* 多模型随意切换
|
* 多模型随意切换
|
||||||
|
* 多搜索引擎随意切换
|
||||||
|
|
||||||
## 接入的模型:
|
## 接入的模型:
|
||||||
|
|
||||||
|
@ -27,6 +29,14 @@
|
||||||
* ollama
|
* ollama
|
||||||
* DALL-E 2
|
* DALL-E 2
|
||||||
|
|
||||||
|
## 接入的搜索引擎
|
||||||
|
|
||||||
|
Google
|
||||||
|
|
||||||
|
Bing (TODO)
|
||||||
|
|
||||||
|
百度 (TODO)
|
||||||
|
|
||||||
## 技术栈
|
## 技术栈
|
||||||
|
|
||||||
该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web)
|
该仓库为后端服务,前端项目见[langchain4j-aideepin-web](https://github.com/moyangzhan/langchain4j-aideepin-web)
|
||||||
|
@ -53,7 +63,7 @@ vue3+typescript+pnpm
|
||||||
|
|
||||||
* 创建数据库aideepin
|
* 创建数据库aideepin
|
||||||
* 执行docs/create.sql
|
* 执行docs/create.sql
|
||||||
* 填充各模型的配置
|
* 填充各模型的配置(至少设置一个)
|
||||||
|
|
||||||
openai的secretKey
|
openai的secretKey
|
||||||
|
|
||||||
|
@ -79,6 +89,15 @@ ollama的配置
|
||||||
update adi_sys_config set value = '{"base_url":"my_ollama_base_url","models":["my model name,eg:tinydolphin"]}' where name = 'ollama_setting';
|
update adi_sys_config set value = '{"base_url":"my_ollama_base_url","models":["my model name,eg:tinydolphin"]}' where name = 'ollama_setting';
|
||||||
```
|
```
|
||||||
|
|
||||||
|
* 填充搜索引擎的配置
|
||||||
|
|
||||||
|
Google的配置
|
||||||
|
|
||||||
|
```
|
||||||
|
update adi_sys_config set value = '{"url":"https://www.googleapis.com/customsearch/v1","key":"my key from cloud.google.com","cx":"my cx from programmablesearchengine.google.com"}' where name = 'google_setting';
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
**b. 修改配置文件**
|
**b. 修改配置文件**
|
||||||
|
|
||||||
* postgresql: application-[dev|prod].xml中的spring.datasource
|
* postgresql: application-[dev|prod].xml中的spring.datasource
|
||||||
|
@ -122,23 +141,30 @@ docker run -d \
|
||||||
|
|
||||||
## 待办:
|
## 待办:
|
||||||
|
|
||||||
* AI搜索
|
增强RAG
|
||||||
* 增强RAG
|
|
||||||
|
增加搜索引擎(BING、百度)
|
||||||
|
|
||||||
## 截图
|
## 截图
|
||||||
|
|
||||||
**AI聊天:**
|
**AI聊天:**
|
||||||
![1691583184761](image/README/1691583184761.png)
|
![1691583184761](image/README/1691583184761.png)
|
||||||
|
|
||||||
![1691583124744](image/README/1691583124744.png "AI绘图")
|
**AI画图:**
|
||||||
|
|
||||||
![1691583329105](image/README/1691583329105.png "token统计")
|
![1691583124744](image/README/1691583124744.png "AI绘图")
|
||||||
|
|
||||||
**知识库:**
|
**知识库:**
|
||||||
![kbindex](image/README/kbidx.png)
|
![kbindex](image/README/kbidx.png)
|
||||||
|
|
||||||
![kb01](image/README/kb01.png)
|
![kb01](image/README/kb01.png)
|
||||||
|
|
||||||
|
**向量化:**
|
||||||
|
|
||||||
![kb02](image/README/kb02.png)
|
![kb02](image/README/kb02.png)
|
||||||
|
|
||||||
![kb03](image/README/kb03.png)
|
![kb03](image/README/kb03.png)
|
||||||
|
|
||||||
|
**额度统计:**
|
||||||
|
|
||||||
|
![1691583329105](https://file+.vscode-resource.vscode-cdn.net/e%3A/WORKSPACE/aideepin/image/README/1691583329105.png "token统计")
|
||||||
|
|
|
@ -11,6 +11,9 @@ spring:
|
||||||
name: AiDeepIn
|
name: AiDeepIn
|
||||||
profiles:
|
profiles:
|
||||||
active: dev
|
active: dev
|
||||||
|
mvc:
|
||||||
|
async:
|
||||||
|
request-timeout: 60000
|
||||||
jackson:
|
jackson:
|
||||||
date-format: "yyyy-MM-dd HH:mm:ss"
|
date-format: "yyyy-MM-dd HH:mm:ss"
|
||||||
time-zone: "GMT+8"
|
time-zone: "GMT+8"
|
||||||
|
|
|
@ -3,7 +3,9 @@ package com.moyz.adi.chat.controller;
|
||||||
import com.moyz.adi.common.dto.LoginReq;
|
import com.moyz.adi.common.dto.LoginReq;
|
||||||
import com.moyz.adi.common.dto.LoginResp;
|
import com.moyz.adi.common.dto.LoginResp;
|
||||||
import com.moyz.adi.common.dto.RegisterReq;
|
import com.moyz.adi.common.dto.RegisterReq;
|
||||||
|
import com.moyz.adi.common.searchengine.SearchEngineContext;
|
||||||
import com.moyz.adi.common.service.UserService;
|
import com.moyz.adi.common.service.UserService;
|
||||||
|
import com.moyz.adi.common.vo.SearchEngineInfo;
|
||||||
import com.ramostear.captcha.HappyCaptcha;
|
import com.ramostear.captcha.HappyCaptcha;
|
||||||
import com.ramostear.captcha.support.CaptchaType;
|
import com.ramostear.captcha.support.CaptchaType;
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
|
@ -23,6 +25,8 @@ import org.springframework.web.bind.annotation.*;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.net.URLEncoder;
|
import java.net.URLEncoder;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static org.springframework.http.HttpHeaders.AUTHORIZATION;
|
import static org.springframework.http.HttpHeaders.AUTHORIZATION;
|
||||||
|
|
||||||
|
@ -123,4 +127,9 @@ public class AuthController {
|
||||||
happyCaptcha.output();
|
happyCaptcha.output();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "Search engine list")
|
||||||
|
@GetMapping(value = "/search-engine/list")
|
||||||
|
public List<SearchEngineInfo> engines() {
|
||||||
|
return SearchEngineContext.NAME_TO_ENGINE.values().stream().collect(Collectors.toList());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,7 +70,7 @@ public class KnowledgeBaseController {
|
||||||
*
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
@PostMapping("/star/{uuid}")
|
@PostMapping("/star/{kbUuid}")
|
||||||
public boolean star(@PathVariable String kbUuid) {
|
public boolean star(@PathVariable String kbUuid) {
|
||||||
return knowledgeBaseService.star(kbUuid);
|
return knowledgeBaseService.star(kbUuid);
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,6 +44,6 @@ public class KnowledgeBaseQAController {
|
||||||
|
|
||||||
@PostMapping("/record/del/{uuid}")
|
@PostMapping("/record/del/{uuid}")
|
||||||
public boolean recordDel(@PathVariable String uuid) {
|
public boolean recordDel(@PathVariable String uuid) {
|
||||||
return knowledgeBaseQaRecordService.softDelele(uuid);
|
return knowledgeBaseQaRecordService.softDelete(uuid);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
package com.moyz.adi.chat.controller;
|
||||||
|
|
||||||
|
import com.moyz.adi.common.dto.AiSearchReq;
|
||||||
|
import com.moyz.adi.common.service.SearchService;
|
||||||
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
|
import jakarta.annotation.Resource;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.validation.annotation.Validated;
|
||||||
|
import org.springframework.web.bind.annotation.PostMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RequestBody;
|
||||||
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
|
@Tag(name = "AI search controller")
|
||||||
|
@RequestMapping("/ai-search/")
|
||||||
|
@RestController
|
||||||
|
public class SearchController {
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private SearchService searchService;
|
||||||
|
|
||||||
|
@Operation(summary = "sse process")
|
||||||
|
@PostMapping(value = "/process", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||||
|
public SseEmitter sseAsk(@RequestBody @Validated AiSearchReq req) {
|
||||||
|
return searchService.search(req.isBriefSearch(), req.getSearchText(), req.getEngineName(), req.getModelName());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
package com.moyz.adi.chat.controller;
|
||||||
|
|
||||||
|
import com.moyz.adi.common.dto.AiSearchResp;
|
||||||
|
import com.moyz.adi.common.service.AiSearchRecordService;
|
||||||
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
|
import jakarta.annotation.Resource;
|
||||||
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
|
||||||
|
@Tag(name = "Ai search record controller")
|
||||||
|
@RequestMapping("/ai-search-record/")
|
||||||
|
@RestController
|
||||||
|
public class SearchRecordController {
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private AiSearchRecordService aiSearchRecordService;
|
||||||
|
|
||||||
|
@Operation(summary = "List by max id")
|
||||||
|
@GetMapping(value = "/list")
|
||||||
|
public AiSearchResp list(@RequestParam(defaultValue = "0") Long maxId, String keyword) {
|
||||||
|
return aiSearchRecordService.listByMaxId(maxId, keyword);
|
||||||
|
}
|
||||||
|
|
||||||
|
@PostMapping("/del/{uuid}")
|
||||||
|
public boolean recordDel(@PathVariable String uuid) {
|
||||||
|
return aiSearchRecordService.softDelete(uuid);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,73 @@
|
||||||
|
package com.moyz.adi.common.base;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.apache.ibatis.type.BaseTypeHandler;
|
||||||
|
import org.apache.ibatis.type.JdbcType;
|
||||||
|
import org.apache.ibatis.type.MappedJdbcTypes;
|
||||||
|
import org.apache.ibatis.type.MappedTypes;
|
||||||
|
import org.postgresql.util.PGobject;
|
||||||
|
|
||||||
|
import java.sql.CallableStatement;
|
||||||
|
import java.sql.PreparedStatement;
|
||||||
|
import java.sql.ResultSet;
|
||||||
|
import java.sql.SQLException;
|
||||||
|
|
||||||
|
@MappedJdbcTypes({JdbcType.JAVA_OBJECT})
|
||||||
|
@MappedTypes({JsonNode.class})
|
||||||
|
public class JsonNodeTypeHandler extends BaseTypeHandler<JsonNode> {
|
||||||
|
|
||||||
|
private static final ObjectMapper objectMapper = new ObjectMapper();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setNonNullParameter(PreparedStatement ps, int i, JsonNode parameter, JdbcType jdbcType)
|
||||||
|
throws SQLException {
|
||||||
|
PGobject jsonObject = new PGobject();
|
||||||
|
jsonObject.setType("jsonb");
|
||||||
|
try {
|
||||||
|
jsonObject.setValue(parameter.toString());
|
||||||
|
ps.setObject(i, jsonObject);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public JsonNode getNullableResult(ResultSet rs, String columnName) throws SQLException {
|
||||||
|
String jsonSource = rs.getString(columnName);
|
||||||
|
if (jsonSource != null) {
|
||||||
|
try {
|
||||||
|
return objectMapper.readTree(jsonSource);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public JsonNode getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
|
||||||
|
String jsonSource = rs.getString(columnIndex);
|
||||||
|
if (jsonSource != null) {
|
||||||
|
try {
|
||||||
|
return objectMapper.readTree(jsonSource);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public JsonNode getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
|
||||||
|
String jsonSource = cs.getString(columnIndex);
|
||||||
|
if (jsonSource != null) {
|
||||||
|
try {
|
||||||
|
return objectMapper.readTree(jsonSource);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,70 @@
|
||||||
|
package com.moyz.adi.common.base;
|
||||||
|
|
||||||
|
import com.moyz.adi.common.dto.SearchEngineResp;
|
||||||
|
import com.moyz.adi.common.util.JsonUtil;
|
||||||
|
import org.apache.ibatis.type.BaseTypeHandler;
|
||||||
|
import org.apache.ibatis.type.JdbcType;
|
||||||
|
import org.apache.ibatis.type.MappedJdbcTypes;
|
||||||
|
import org.apache.ibatis.type.MappedTypes;
|
||||||
|
import org.postgresql.util.PGobject;
|
||||||
|
|
||||||
|
import java.sql.CallableStatement;
|
||||||
|
import java.sql.PreparedStatement;
|
||||||
|
import java.sql.ResultSet;
|
||||||
|
import java.sql.SQLException;
|
||||||
|
|
||||||
|
@MappedJdbcTypes({JdbcType.JAVA_OBJECT})
|
||||||
|
@MappedTypes({SearchEngineResp.class})
|
||||||
|
public class SearchEngineRespTypeHandler extends BaseTypeHandler<SearchEngineResp> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setNonNullParameter(PreparedStatement ps, int i, SearchEngineResp parameter, JdbcType jdbcType) {
|
||||||
|
PGobject jsonObject = new PGobject();
|
||||||
|
jsonObject.setType("jsonb");
|
||||||
|
try {
|
||||||
|
jsonObject.setValue(JsonUtil.toJson(parameter));
|
||||||
|
ps.setObject(i, jsonObject);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SearchEngineResp getNullableResult(ResultSet rs, String columnName) throws SQLException {
|
||||||
|
String jsonSource = rs.getString(columnName);
|
||||||
|
if (jsonSource != null) {
|
||||||
|
try {
|
||||||
|
return JsonUtil.fromJson(jsonSource, SearchEngineResp.class);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SearchEngineResp getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
|
||||||
|
String jsonSource = rs.getString(columnIndex);
|
||||||
|
if (jsonSource != null) {
|
||||||
|
try {
|
||||||
|
return JsonUtil.fromJson(jsonSource, SearchEngineResp.class);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SearchEngineResp getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
|
||||||
|
String jsonSource = cs.getString(columnIndex);
|
||||||
|
if (jsonSource != null) {
|
||||||
|
try {
|
||||||
|
return JsonUtil.fromJson(jsonSource, SearchEngineResp.class);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
|
@ -10,11 +10,14 @@ import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
|
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
|
||||||
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
|
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
|
||||||
import com.google.common.collect.Lists;
|
import com.moyz.adi.common.base.SearchEngineRespTypeHandler;
|
||||||
|
import com.moyz.adi.common.dto.SearchEngineResp;
|
||||||
|
import com.moyz.adi.common.service.RAGService;
|
||||||
import com.moyz.adi.common.util.LocalDateTimeUtil;
|
import com.moyz.adi.common.util.LocalDateTimeUtil;
|
||||||
import com.pgvector.PGvector;
|
import com.pgvector.PGvector;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.ibatis.session.SqlSessionFactory;
|
import org.apache.ibatis.session.SqlSessionFactory;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.context.annotation.Bean;
|
import org.springframework.context.annotation.Bean;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.context.annotation.Primary;
|
import org.springframework.context.annotation.Primary;
|
||||||
|
@ -32,6 +35,15 @@ import javax.sql.DataSource;
|
||||||
@Configuration
|
@Configuration
|
||||||
public class BeanConfig {
|
public class BeanConfig {
|
||||||
|
|
||||||
|
@Value("${spring.datasource.url}")
|
||||||
|
private String dataBaseUrl;
|
||||||
|
|
||||||
|
@Value("${spring.datasource.username}")
|
||||||
|
private String dataBaseUserName;
|
||||||
|
|
||||||
|
@Value("${spring.datasource.password}")
|
||||||
|
private String dataBasePassword;
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
public RestTemplate restTemplate() {
|
public RestTemplate restTemplate() {
|
||||||
log.info("Configuration:create restTemplate");
|
log.info("Configuration:create restTemplate");
|
||||||
|
@ -42,7 +54,7 @@ public class BeanConfig {
|
||||||
requestFactory.setReadTimeout(60000);
|
requestFactory.setReadTimeout(60000);
|
||||||
RestTemplate restTemplate = new RestTemplate();
|
RestTemplate restTemplate = new RestTemplate();
|
||||||
// 注册LOG拦截器
|
// 注册LOG拦截器
|
||||||
restTemplate.setInterceptors(Lists.newArrayList(new LogClientHttpRequestInterceptor()));
|
// restTemplate.setInterceptors(Lists.newArrayList(new LogClientHttpRequestInterceptor()));
|
||||||
restTemplate.setRequestFactory(new BufferingClientHttpRequestFactory(requestFactory));
|
restTemplate.setRequestFactory(new BufferingClientHttpRequestFactory(requestFactory));
|
||||||
|
|
||||||
return restTemplate;
|
return restTemplate;
|
||||||
|
@ -95,12 +107,28 @@ public class BeanConfig {
|
||||||
bean.setMapperLocations(
|
bean.setMapperLocations(
|
||||||
new PathMatchingResourcePatternResolver().getResources("classpath*:/mapper/*.xml"));
|
new PathMatchingResourcePatternResolver().getResources("classpath*:/mapper/*.xml"));
|
||||||
MybatisConfiguration configuration = bean.getConfiguration();
|
MybatisConfiguration configuration = bean.getConfiguration();
|
||||||
if(null == configuration){
|
if (null == configuration) {
|
||||||
configuration = new MybatisConfiguration();
|
configuration = new MybatisConfiguration();
|
||||||
bean.setConfiguration(configuration);
|
bean.setConfiguration(configuration);
|
||||||
}
|
}
|
||||||
bean.getConfiguration().getTypeHandlerRegistry().register(PGvector.class, PostgresVectorTypeHandler.class);
|
bean.getConfiguration().getTypeHandlerRegistry().register(PGvector.class, PostgresVectorTypeHandler.class);
|
||||||
|
bean.getConfiguration().getTypeHandlerRegistry().register(SearchEngineResp.class, SearchEngineRespTypeHandler.class);
|
||||||
return bean.getObject();
|
return bean.getObject();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Bean(name = "kbRagService")
|
||||||
|
@Primary
|
||||||
|
public RAGService initKnowledgeBaseRAGService() {
|
||||||
|
RAGService ragService = new RAGService("adi_knowledge_base_embedding", dataBaseUrl, dataBaseUserName, dataBasePassword);
|
||||||
|
ragService.init();
|
||||||
|
return ragService;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean(name = "searchRagService")
|
||||||
|
public RAGService initSearchRAGService() {
|
||||||
|
RAGService ragService = new RAGService("adi_ai_search_embedding", dataBaseUrl, dataBaseUserName, dataBasePassword);
|
||||||
|
ragService.init();
|
||||||
|
return ragService;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
package com.moyz.adi.common.cosntant;
|
package com.moyz.adi.common.cosntant;
|
||||||
|
|
||||||
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class AdiConstant {
|
public class AdiConstant {
|
||||||
|
|
||||||
public static final int DEFAULT_PAGE_SIZE = 1;
|
public static final int DEFAULT_PAGE_SIZE = 10;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 验证码id过期时间:1小时
|
* 验证码id过期时间:1小时
|
||||||
|
@ -53,6 +55,13 @@ public class AdiConstant {
|
||||||
|
|
||||||
public static final List<String> OPENAI_CREATE_IMAGE_SIZES = List.of("256x256", "512x512", "1024x1024");
|
public static final List<String> OPENAI_CREATE_IMAGE_SIZES = List.of("256x256", "512x512", "1024x1024");
|
||||||
|
|
||||||
|
public static final PromptTemplate PROMPT_TEMPLATE = PromptTemplate.from("""
|
||||||
|
根据以下已知信息:
|
||||||
|
{{information}}
|
||||||
|
尽可能准确地回答用户的问题,以下是用户的问题:
|
||||||
|
{{question}}
|
||||||
|
注意,回答的内容不能让用户感知到已知信息的存在
|
||||||
|
""");
|
||||||
|
|
||||||
public static class GenerateImage {
|
public static class GenerateImage {
|
||||||
public static final int INTERACTING_METHOD_GENERATE_IMAGE = 1;
|
public static final int INTERACTING_METHOD_GENERATE_IMAGE = 1;
|
||||||
|
@ -64,11 +73,21 @@ public class AdiConstant {
|
||||||
public static final int STATUS_SUCCESS = 3;
|
public static final int STATUS_SUCCESS = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static class EmbeddingMetadataKey {
|
||||||
|
public static final String KB_UUID = "kb_uuid";
|
||||||
|
public static final String KB_ITEM_UUID = "kb_item_uuid";
|
||||||
|
public static final String ENGINE_NAME = "engine_name";
|
||||||
|
public static final String SEARCH_UUID = "search_uuid";
|
||||||
|
}
|
||||||
|
|
||||||
public static class SysConfigKey {
|
public static class SysConfigKey {
|
||||||
public static final String OPENAI_SETTING = "openai_setting";
|
public static final String OPENAI_SETTING = "openai_setting";
|
||||||
public static final String DASHSCOPE_SETTING = "dashscope_setting";
|
public static final String DASHSCOPE_SETTING = "dashscope_setting";
|
||||||
public static final String QIANFAN_SETTING = "qianfan_setting";
|
public static final String QIANFAN_SETTING = "qianfan_setting";
|
||||||
public static final String OLLAMA_SETTING = "ollama_setting";
|
public static final String OLLAMA_SETTING = "ollama_setting";
|
||||||
|
public static final String GOOGLE_SETTING = "google_setting";
|
||||||
|
public static final String BING_SETTING = "bing_setting";
|
||||||
|
public static final String BAIDU_SETTING = "baidu_setting";
|
||||||
public static final String REQUEST_TEXT_RATE_LIMIT = "request_text_rate_limit";
|
public static final String REQUEST_TEXT_RATE_LIMIT = "request_text_rate_limit";
|
||||||
public static final String REQUEST_IMAGE_RATE_LIMIT = "request_image_rate_limit";
|
public static final String REQUEST_IMAGE_RATE_LIMIT = "request_image_rate_limit";
|
||||||
public static final String CONVERSATION_MAX_NUM = "conversation_max_num";
|
public static final String CONVERSATION_MAX_NUM = "conversation_max_num";
|
||||||
|
@ -78,8 +97,25 @@ public class AdiConstant {
|
||||||
public static final String QUOTA_BY_REQUEST_MONTHLY = "quota_by_request_monthly";
|
public static final String QUOTA_BY_REQUEST_MONTHLY = "quota_by_request_monthly";
|
||||||
public static final String QUOTA_BY_IMAGE_DAILY = "quota_by_image_daily";
|
public static final String QUOTA_BY_IMAGE_DAILY = "quota_by_image_daily";
|
||||||
public static final String QUOTA_BY_IMAGE_MONTHLY = "quota_by_image_monthly";
|
public static final String QUOTA_BY_IMAGE_MONTHLY = "quota_by_image_monthly";
|
||||||
|
public static final String QUOTA_BY_QA_ASK_DAILY = "quota_by_qa_ask_daily";
|
||||||
}
|
}
|
||||||
|
|
||||||
public static final String[] POI_DOC_TYPES = {"doc", "docx", "ppt", "pptx", "xls", "xlsx"};
|
public static final String[] POI_DOC_TYPES = {"doc", "docx", "ppt", "pptx", "xls", "xlsx"};
|
||||||
|
|
||||||
|
public static class SearchEngineName {
|
||||||
|
public static final String GOOGLE = "google";
|
||||||
|
public static final String BING = "bing";
|
||||||
|
public static final String BAIDU = "baidu";
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class SSEEventName {
|
||||||
|
public static final String START = "[START]";
|
||||||
|
public static final String DONE = "[DONE]";
|
||||||
|
public static final String ERROR = "[ERROR]";
|
||||||
|
|
||||||
|
public static final String AI_SEARCH_SOURCE_LINKS = "[SOURCE_LINKS]";
|
||||||
|
}
|
||||||
|
|
||||||
|
public static final int RAG_TYPE_KB = 1;
|
||||||
|
public static final int RAG_TYPE_SEARCH = 2;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
package com.moyz.adi.common.dto;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Builder
|
||||||
|
@Data
|
||||||
|
public class AiSearchRecordResp {
|
||||||
|
|
||||||
|
private String uuid;
|
||||||
|
|
||||||
|
private String question;
|
||||||
|
|
||||||
|
private SearchEngineResp searchEngineResp;
|
||||||
|
|
||||||
|
private String prompt;
|
||||||
|
|
||||||
|
private Integer promptTokens;
|
||||||
|
|
||||||
|
private String answer;
|
||||||
|
|
||||||
|
private Integer answerTokens;
|
||||||
|
|
||||||
|
private String userUuid;
|
||||||
|
|
||||||
|
private LocalDateTime createTime;
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
package com.moyz.adi.common.dto;
|
||||||
|
|
||||||
|
import jakarta.validation.constraints.NotBlank;
|
||||||
|
import lombok.Data;
|
||||||
|
import org.springframework.validation.annotation.Validated;
|
||||||
|
|
||||||
|
@Validated
|
||||||
|
@Data
|
||||||
|
public class AiSearchReq {
|
||||||
|
|
||||||
|
@NotBlank
|
||||||
|
private String searchText;
|
||||||
|
|
||||||
|
private String engineName;
|
||||||
|
|
||||||
|
private String modelName;
|
||||||
|
|
||||||
|
private boolean briefSearch;
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
package com.moyz.adi.common.dto;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class AiSearchResp {
|
||||||
|
|
||||||
|
private Long minId;
|
||||||
|
|
||||||
|
private List<AiSearchRecordResp> records;
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
package com.moyz.adi.common.dto;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class GoogleSearchError {
|
||||||
|
private Integer code;
|
||||||
|
private String message;
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
package com.moyz.adi.common.dto;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class GoogleSearchResp {
|
||||||
|
private String kind;
|
||||||
|
private Queries queries;
|
||||||
|
private SearchInformation searchInformation;
|
||||||
|
private List<Item> items;
|
||||||
|
private GoogleSearchError error;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class Queries {
|
||||||
|
private Request[] request;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class Request {
|
||||||
|
private String title;
|
||||||
|
private String totalResults;
|
||||||
|
private String searchTerms;
|
||||||
|
private Integer count;
|
||||||
|
private Integer startIndex;
|
||||||
|
private String inputEncoding;
|
||||||
|
private String outputEncoding;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class SearchInformation {
|
||||||
|
private double searchTime;
|
||||||
|
private String formattedSearchTime;
|
||||||
|
private String totalResults;
|
||||||
|
private String formattedTotalResults;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class Item {
|
||||||
|
private String kind;
|
||||||
|
private String title;
|
||||||
|
private String htmlTitle;
|
||||||
|
private String link;
|
||||||
|
private String snippet;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,12 @@
|
||||||
|
package com.moyz.adi.common.dto;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.experimental.Accessors;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Accessors(chain = true)
|
||||||
|
public class SearchEngineResp {
|
||||||
|
private List<SearchResultItem> items;
|
||||||
|
}
|
|
@ -0,0 +1,11 @@
|
||||||
|
package com.moyz.adi.common.dto;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SearchResult {
|
||||||
|
private String errorMessage;
|
||||||
|
private List<SearchResultItem> items;
|
||||||
|
}
|
|
@ -0,0 +1,11 @@
|
||||||
|
package com.moyz.adi.common.dto;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SearchResultItem {
|
||||||
|
private String title;
|
||||||
|
private String link;
|
||||||
|
private String snippet;
|
||||||
|
private String content;
|
||||||
|
}
|
|
@ -0,0 +1,50 @@
|
||||||
|
package com.moyz.adi.common.entity;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableField;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableName;
|
||||||
|
import com.moyz.adi.common.base.SearchEngineRespTypeHandler;
|
||||||
|
import com.moyz.adi.common.dto.SearchEngineResp;
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
|
import lombok.Data;
|
||||||
|
import org.apache.ibatis.type.JdbcType;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@TableName("adi_ai_search_record")
|
||||||
|
@Schema(title = "AiSearchRecord对象", description = "AI搜索记录表")
|
||||||
|
public class AiSearchRecord extends BaseEntity {
|
||||||
|
|
||||||
|
@TableField("uuid")
|
||||||
|
private String uuid;
|
||||||
|
|
||||||
|
@Schema(title = "问题")
|
||||||
|
@TableField("question")
|
||||||
|
private String question;
|
||||||
|
|
||||||
|
@Schema(title = "Search engine's response content")
|
||||||
|
@TableField(value = "search_engine_response", jdbcType = JdbcType.JAVA_OBJECT, typeHandler = SearchEngineRespTypeHandler.class)
|
||||||
|
private SearchEngineResp searchEngineResp;
|
||||||
|
|
||||||
|
@Schema(title = "最终提供给LLM的提示词")
|
||||||
|
@TableField("prompt")
|
||||||
|
private String prompt;
|
||||||
|
|
||||||
|
@Schema(title = "提供给LLM的提示词所消耗的token数量")
|
||||||
|
@TableField("prompt_tokens")
|
||||||
|
private Integer promptTokens;
|
||||||
|
|
||||||
|
@Schema(title = "答案")
|
||||||
|
@TableField("answer")
|
||||||
|
private String answer;
|
||||||
|
|
||||||
|
@Schema(title = "答案消耗的token")
|
||||||
|
@TableField("answer_tokens")
|
||||||
|
private Integer answerTokens;
|
||||||
|
|
||||||
|
@Schema(title = "提问用户uuid")
|
||||||
|
@TableField("user_uuid")
|
||||||
|
private String userUuid;
|
||||||
|
|
||||||
|
@Schema(title = "提问用户id")
|
||||||
|
@TableField("user_id")
|
||||||
|
private Long userId;
|
||||||
|
}
|
|
@ -34,7 +34,7 @@ public enum ErrorEnum {
|
||||||
B_MESSAGE_NOT_FOUND("B0008", "消息不存在"),
|
B_MESSAGE_NOT_FOUND("B0008", "消息不存在"),
|
||||||
B_LLM_SERVICE_DISABLED("B0009", "LLM服务不可用"),
|
B_LLM_SERVICE_DISABLED("B0009", "LLM服务不可用"),
|
||||||
B_KNOWLEDGE_BASE_IS_EMPTY("B0010", "知识库内容为空"),
|
B_KNOWLEDGE_BASE_IS_EMPTY("B0010", "知识库内容为空"),
|
||||||
B_KNOWLEDGE_BASE_NO_ANSWER("B0011", "[无答案]")
|
B_NO_ANSWER("B0011", "[无答案]")
|
||||||
;
|
;
|
||||||
|
|
||||||
private String code;
|
private String code;
|
||||||
|
|
|
@ -17,10 +17,10 @@ public class QuotaHelper {
|
||||||
private UserDayCostService userDayCostService;
|
private UserDayCostService userDayCostService;
|
||||||
|
|
||||||
public ErrorEnum checkTextQuota(User user) {
|
public ErrorEnum checkTextQuota(User user) {
|
||||||
if (StringUtils.isNotBlank(user.getSecretKey())) {
|
// if (StringUtils.isNotBlank(user.getSecretKey())) {
|
||||||
log.info("Custom secret key,dont need to check text request quota,userId:{}", user.getId());
|
// log.info("Custom secret key,dont need to check text request quota,userId:{}", user.getId());
|
||||||
return null;
|
// return null;
|
||||||
}
|
// }
|
||||||
int userQuotaByTokenDay = user.getQuotaByTokenDaily();
|
int userQuotaByTokenDay = user.getQuotaByTokenDaily();
|
||||||
int userQuotaByTokenMonth = user.getQuotaByTokenMonthly();
|
int userQuotaByTokenMonth = user.getQuotaByTokenMonthly();
|
||||||
int userQuotaByRequestDay = user.getQuotaByRequestDaily();
|
int userQuotaByRequestDay = user.getQuotaByRequestDaily();
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
package com.moyz.adi.common.helper;
|
package com.moyz.adi.common.helper;
|
||||||
|
|
||||||
|
import com.moyz.adi.common.cosntant.AdiConstant;
|
||||||
import com.moyz.adi.common.cosntant.RedisKeyConstant;
|
import com.moyz.adi.common.cosntant.RedisKeyConstant;
|
||||||
import com.moyz.adi.common.entity.User;
|
import com.moyz.adi.common.entity.User;
|
||||||
import com.moyz.adi.common.interfaces.TriConsumer;
|
import com.moyz.adi.common.interfaces.TriConsumer;
|
||||||
|
@ -28,35 +29,45 @@ public class SSEEmitterHelper {
|
||||||
@Resource
|
@Resource
|
||||||
private RateLimitHelper rateLimitHelper;
|
private RateLimitHelper rateLimitHelper;
|
||||||
|
|
||||||
public void process(User user, SseAskParams sseAskParams, TriConsumer<String, PromptMeta, AnswerMeta> consumer) {
|
public boolean checkOrComplete(User user, SseEmitter sseEmitter) {
|
||||||
SseEmitter sseEmitter = sseAskParams.getSseEmitter();
|
//Check: rate limit
|
||||||
|
|
||||||
//rate limit by system
|
|
||||||
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
|
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
|
||||||
if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) {
|
if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) {
|
||||||
sendErrorMsg(sseEmitter, "访问太过频繁");
|
sendErrorAndComplete(user.getId(), sseEmitter, "访问太过频繁");
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
//Check: If still waiting response
|
//Check: If still waiting response
|
||||||
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
|
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
|
||||||
String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
|
String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
|
||||||
if (StringUtils.isNotBlank(askingVal)) {
|
if (StringUtils.isNotBlank(askingVal)) {
|
||||||
sendErrorMsg(sseEmitter, "正在回复中...");
|
sendErrorAndComplete(user.getId(), sseEmitter, "正在回复中...");
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void startSse(User user, SseEmitter sseEmitter) {
|
||||||
|
|
||||||
|
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
|
||||||
stringRedisTemplate.opsForValue().set(askingKey, "1", 15, TimeUnit.SECONDS);
|
stringRedisTemplate.opsForValue().set(askingKey, "1", 15, TimeUnit.SECONDS);
|
||||||
|
|
||||||
|
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
|
||||||
|
rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG);
|
||||||
try {
|
try {
|
||||||
sseEmitter.send(SseEmitter.event().name("[START]"));
|
sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.START));
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
log.error("error", e);
|
log.error("error", e);
|
||||||
sseEmitter.completeWithError(e);
|
sseEmitter.completeWithError(e);
|
||||||
stringRedisTemplate.delete(askingKey);
|
stringRedisTemplate.delete(askingKey);
|
||||||
return;
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG);
|
|
||||||
|
|
||||||
|
public void processAndPushToModel(User user, SseAskParams sseAskParams, TriConsumer<String, PromptMeta, AnswerMeta> consumer) {
|
||||||
|
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
|
||||||
|
|
||||||
|
SseEmitter sseEmitter = sseAskParams.getSseEmitter();
|
||||||
sseEmitter.onCompletion(() -> {
|
sseEmitter.onCompletion(() -> {
|
||||||
log.info("response complete,uid:{}", user.getId());
|
log.info("response complete,uid:{}", user.getId());
|
||||||
});
|
});
|
||||||
|
@ -65,7 +76,7 @@ public class SSEEmitterHelper {
|
||||||
throwable -> {
|
throwable -> {
|
||||||
try {
|
try {
|
||||||
log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable);
|
log.error("sseEmitter error,uid:{},on error:{}", user.getId(), throwable);
|
||||||
sseEmitter.send(SseEmitter.event().name("[ERROR]").data(throwable.getMessage()));
|
sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.ERROR).data(throwable.getMessage()));
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
log.error("error", e);
|
log.error("error", e);
|
||||||
} finally {
|
} finally {
|
||||||
|
@ -84,22 +95,29 @@ public class SSEEmitterHelper {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
public void sendAndComplete(SseEmitter sseEmitter, String msg){
|
public void sendAndComplete(long userId, SseEmitter sseEmitter, String msg) {
|
||||||
try {
|
try {
|
||||||
sseEmitter.send(SseEmitter.event().name("[START]"));
|
sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.START));
|
||||||
sseEmitter.send(SseEmitter.event().name("[DONE]").data(msg));
|
sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.DONE).data(msg));
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
sseEmitter.complete();
|
sseEmitter.complete();
|
||||||
|
delSseRequesting(userId);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void sendErrorMsg(SseEmitter sseEmitter, String errorMsg) {
|
public void sendErrorAndComplete(long userId, SseEmitter sseEmitter, String errorMsg) {
|
||||||
try {
|
try {
|
||||||
sseEmitter.send(SseEmitter.event().name("[ERROR]").data(errorMsg));
|
sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.ERROR).data(errorMsg));
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
sseEmitter.complete();
|
sseEmitter.complete();
|
||||||
|
delSseRequesting(userId);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void delSseRequesting(long userId) {
|
||||||
|
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, userId);
|
||||||
|
stringRedisTemplate.delete(askingKey);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,13 +39,17 @@ public abstract class AbstractImageModelService<T> {
|
||||||
|
|
||||||
protected ImageModel imageModel;
|
protected ImageModel imageModel;
|
||||||
|
|
||||||
public AbstractImageModelService(String modelName, String settingName, Class<T> clazz, Proxy proxy){
|
public AbstractImageModelService(String modelName, String settingName, Class<T> clazz) {
|
||||||
this.modelName = modelName;
|
this.modelName = modelName;
|
||||||
this.proxy = proxy;
|
|
||||||
String st = LocalCache.CONFIGS.get(settingName);
|
String st = LocalCache.CONFIGS.get(settingName);
|
||||||
setting = JsonUtil.fromJson(st, clazz);
|
setting = JsonUtil.fromJson(st, clazz);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public AbstractImageModelService setProxy(Proxy proxy) {
|
||||||
|
this.proxy = proxy;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public ImageModel getImageModel(User user, String size) {
|
public ImageModel getImageModel(User user, String size) {
|
||||||
if (null != imageModel) {
|
if (null != imageModel) {
|
||||||
return imageModel;
|
return imageModel;
|
||||||
|
@ -56,6 +60,7 @@ public abstract class AbstractImageModelService<T> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 检测该service是否可用(不可用的情况通过是没有配置key)
|
* 检测该service是否可用(不可用的情况通过是没有配置key)
|
||||||
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public abstract boolean isEnabled();
|
public abstract boolean isEnabled();
|
||||||
|
|
|
@ -36,13 +36,17 @@ public abstract class AbstractLLMService<T> {
|
||||||
protected StreamingChatLanguageModel streamingChatLanguageModel;
|
protected StreamingChatLanguageModel streamingChatLanguageModel;
|
||||||
protected ChatLanguageModel chatLanguageModel;
|
protected ChatLanguageModel chatLanguageModel;
|
||||||
|
|
||||||
public AbstractLLMService(String modelName, String settingName, Class<T> clazz, Proxy proxy) {
|
public AbstractLLMService(String modelName, String settingName, Class<T> clazz) {
|
||||||
this.modelName = modelName;
|
this.modelName = modelName;
|
||||||
this.proxy = proxy;
|
|
||||||
String st = LocalCache.CONFIGS.get(settingName);
|
String st = LocalCache.CONFIGS.get(settingName);
|
||||||
setting = JsonUtil.fromJson(st, clazz);
|
setting = JsonUtil.fromJson(st, clazz);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public AbstractLLMService setProxy(Proxy proxy) {
|
||||||
|
this.proxy = proxy;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 检测该service是否可用(不可用的情况通常是没有配置key)
|
* 检测该service是否可用(不可用的情况通常是没有配置key)
|
||||||
*
|
*
|
||||||
|
@ -73,7 +77,7 @@ public abstract class AbstractLLMService<T> {
|
||||||
protected abstract String parseError(Object error);
|
protected abstract String parseError(Object error);
|
||||||
|
|
||||||
public Response<AiMessage> chat(ChatMessage chatMessage) {
|
public Response<AiMessage> chat(ChatMessage chatMessage) {
|
||||||
if(!isEnabled()){
|
if (!isEnabled()) {
|
||||||
log.error("llm service is disabled");
|
log.error("llm service is disabled");
|
||||||
throw new BaseException(B_LLM_SERVICE_DISABLED);
|
throw new BaseException(B_LLM_SERVICE_DISABLED);
|
||||||
}
|
}
|
||||||
|
@ -81,7 +85,7 @@ public abstract class AbstractLLMService<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void sseChat(SseAskParams params, TriConsumer<String, PromptMeta, AnswerMeta> consumer) {
|
public void sseChat(SseAskParams params, TriConsumer<String, PromptMeta, AnswerMeta> consumer) {
|
||||||
if(!isEnabled()){
|
if (!isEnabled()) {
|
||||||
log.error("llm service is disabled");
|
log.error("llm service is disabled");
|
||||||
throw new BaseException(B_LLM_SERVICE_DISABLED);
|
throw new BaseException(B_LLM_SERVICE_DISABLED);
|
||||||
}
|
}
|
||||||
|
@ -131,7 +135,7 @@ public abstract class AbstractLLMService<T> {
|
||||||
log.error("stream error", error);
|
log.error("stream error", error);
|
||||||
try {
|
try {
|
||||||
String errorMsg = parseError(error);
|
String errorMsg = parseError(error);
|
||||||
if(StringUtils.isBlank(errorMsg)){
|
if (StringUtils.isBlank(errorMsg)) {
|
||||||
errorMsg = error.getMessage();
|
errorMsg = error.getMessage();
|
||||||
}
|
}
|
||||||
params.getSseEmitter().send(SseEmitter.event().name("[ERROR]").data(errorMsg));
|
params.getSseEmitter().send(SseEmitter.event().name("[ERROR]").data(errorMsg));
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
package com.moyz.adi.common.interfaces;
|
||||||
|
|
||||||
|
import com.moyz.adi.common.dto.SearchResult;
|
||||||
|
import com.moyz.adi.common.util.JsonUtil;
|
||||||
|
import com.moyz.adi.common.util.LocalCache;
|
||||||
|
import com.moyz.adi.common.util.SpringUtil;
|
||||||
|
import org.springframework.http.client.SimpleClientHttpRequestFactory;
|
||||||
|
import org.springframework.web.client.RestTemplate;
|
||||||
|
|
||||||
|
import java.net.Proxy;
|
||||||
|
|
||||||
|
public abstract class AbstractSearchEngine<T> {
|
||||||
|
|
||||||
|
protected String engineName;
|
||||||
|
|
||||||
|
protected Proxy proxy;
|
||||||
|
|
||||||
|
public AbstractSearchEngine(String engineName, String settingName, Class<T> clazz) {
|
||||||
|
this.engineName = engineName;
|
||||||
|
String st = LocalCache.CONFIGS.get(settingName);
|
||||||
|
setting = JsonUtil.fromJson(st, clazz);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
protected T setting;
|
||||||
|
|
||||||
|
public abstract boolean isEnabled();
|
||||||
|
|
||||||
|
public abstract SearchResult search(String searchTxt);
|
||||||
|
|
||||||
|
public AbstractSearchEngine setProxy(Proxy proxy) {
|
||||||
|
this.proxy = proxy;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected RestTemplate getRestTemplate() {
|
||||||
|
RestTemplate restTemplate = SpringUtil.getBean(RestTemplate.class);
|
||||||
|
if (null != proxy) {
|
||||||
|
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
|
||||||
|
requestFactory.setProxy(proxy);
|
||||||
|
restTemplate.setRequestFactory(requestFactory);
|
||||||
|
}
|
||||||
|
return restTemplate;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,11 @@
|
||||||
|
package com.moyz.adi.common.interfaces;
|
||||||
|
|
||||||
|
public abstract class AbstractSearchService {
|
||||||
|
|
||||||
|
public abstract boolean isEnabled();
|
||||||
|
|
||||||
|
public abstract void briefSearch(String question);
|
||||||
|
|
||||||
|
public abstract void detailSearch(String question);
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
package com.moyz.adi.common.mapper;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||||
|
import com.moyz.adi.common.entity.AiSearchRecord;
|
||||||
|
import org.apache.ibatis.annotations.Mapper;
|
||||||
|
|
||||||
|
@Mapper
|
||||||
|
public interface AiSearchRecordMapper extends BaseMapper<AiSearchRecord> {
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
package com.moyz.adi.common.searchengine;
|
||||||
|
|
||||||
|
import com.moyz.adi.common.cosntant.AdiConstant;
|
||||||
|
import com.moyz.adi.common.dto.GoogleSearchResp;
|
||||||
|
import com.moyz.adi.common.dto.SearchResult;
|
||||||
|
import com.moyz.adi.common.dto.SearchResultItem;
|
||||||
|
import com.moyz.adi.common.interfaces.AbstractSearchEngine;
|
||||||
|
import com.moyz.adi.common.util.MPPageUtil;
|
||||||
|
import com.moyz.adi.common.vo.GoogleSetting;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
|
|
||||||
|
import java.text.MessageFormat;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static com.moyz.adi.common.cosntant.AdiConstant.SysConfigKey.GOOGLE_SETTING;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class GoogleSearchEngine extends AbstractSearchEngine<GoogleSetting> {
|
||||||
|
|
||||||
|
public GoogleSearchEngine() {
|
||||||
|
super(AdiConstant.SearchEngineName.GOOGLE, GOOGLE_SETTING, GoogleSetting.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isEnabled() {
|
||||||
|
return StringUtils.isNoneBlank(setting.getKey(), setting.getCx());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SearchResult search(String searchTxt) {
|
||||||
|
SearchResult result = new SearchResult();
|
||||||
|
List<SearchResultItem> items = new ArrayList<>();
|
||||||
|
try {
|
||||||
|
ResponseEntity<GoogleSearchResp> resp = getRestTemplate().getForEntity(MessageFormat.format("{0}?key={1}&cx={2}&q={3}", setting.getUrl(), setting.getKey(), setting.getCx(), searchTxt), GoogleSearchResp.class);
|
||||||
|
if (null != resp && HttpStatus.OK.isSameCodeAs(resp.getStatusCode())) {
|
||||||
|
GoogleSearchResp googleSearchResp = resp.getBody();
|
||||||
|
if (null != googleSearchResp.getError()) {
|
||||||
|
log.error("google search error,code:{},message:{}", googleSearchResp.getError().getCode(), googleSearchResp.getError().getMessage());
|
||||||
|
result.setErrorMessage(googleSearchResp.getError().getMessage());
|
||||||
|
} else {
|
||||||
|
log.info("google response:{}", resp);
|
||||||
|
items = MPPageUtil.convertTo(googleSearchResp.getItems(), SearchResultItem.class);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("google search error", e);
|
||||||
|
}
|
||||||
|
result.setItems(items);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
package com.moyz.adi.common.searchengine;
|
||||||
|
|
||||||
|
import com.moyz.adi.common.cosntant.AdiConstant;
|
||||||
|
import com.moyz.adi.common.interfaces.AbstractSearchEngine;
|
||||||
|
import com.moyz.adi.common.vo.SearchEngineInfo;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Search engine context. strategy design model
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class SearchEngineContext {
|
||||||
|
public static final Map<String, SearchEngineInfo> NAME_TO_ENGINE = new LinkedHashMap<>();
|
||||||
|
|
||||||
|
private AbstractSearchEngine iSearchEngine;
|
||||||
|
|
||||||
|
public SearchEngineContext(String searchEngineName) {
|
||||||
|
if (null == NAME_TO_ENGINE.get(searchEngineName)) {
|
||||||
|
log.warn("︿︿︿ Can not find {}, use the default engine GOOGLE ︿︿︿", searchEngineName);
|
||||||
|
iSearchEngine = NAME_TO_ENGINE.get(AdiConstant.SearchEngineName.GOOGLE).getEngine();
|
||||||
|
} else {
|
||||||
|
iSearchEngine = NAME_TO_ENGINE.get(searchEngineName).getEngine();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void addEngine(String engineName, AbstractSearchEngine searchEngine) {
|
||||||
|
SearchEngineInfo info = new SearchEngineInfo();
|
||||||
|
info.setName(engineName);
|
||||||
|
info.setEnable(searchEngine.isEnabled());
|
||||||
|
info.setEngine(searchEngine);
|
||||||
|
NAME_TO_ENGINE.put(engineName, info);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AbstractSearchEngine getEngine() {
|
||||||
|
return iSearchEngine;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,82 @@
|
||||||
|
package com.moyz.adi.common.service;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
|
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||||
|
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
|
||||||
|
import com.moyz.adi.common.base.ThreadContext;
|
||||||
|
import com.moyz.adi.common.dto.AiSearchRecordResp;
|
||||||
|
import com.moyz.adi.common.dto.AiSearchResp;
|
||||||
|
import com.moyz.adi.common.dto.SearchEngineResp;
|
||||||
|
import com.moyz.adi.common.entity.AiSearchRecord;
|
||||||
|
import com.moyz.adi.common.exception.BaseException;
|
||||||
|
import com.moyz.adi.common.mapper.AiSearchRecordMapper;
|
||||||
|
import com.moyz.adi.common.util.BizPager;
|
||||||
|
import com.moyz.adi.common.util.MPPageUtil;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Ai search
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
public class AiSearchRecordService extends ServiceImpl<AiSearchRecordMapper, AiSearchRecord> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* List search records
|
||||||
|
*
|
||||||
|
* @param maxId Anchor id
|
||||||
|
* @param keyword user's question
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public AiSearchResp listByMaxId(Long maxId, String keyword) {
|
||||||
|
LambdaQueryWrapper<AiSearchRecord> wrapper = new LambdaQueryWrapper<>();
|
||||||
|
wrapper.eq(AiSearchRecord::getUserId, ThreadContext.getCurrentUserId());
|
||||||
|
wrapper.eq(AiSearchRecord::getIsDeleted, false);
|
||||||
|
if (StringUtils.isNotBlank(keyword)) {
|
||||||
|
wrapper.like(AiSearchRecord::getQuestion, keyword);
|
||||||
|
}
|
||||||
|
AiSearchResp result = new AiSearchResp();
|
||||||
|
BizPager.listByMaxId(maxId, wrapper, this, AiSearchRecord::getId, (recordList, minId) -> {
|
||||||
|
List<AiSearchRecordResp> list = MPPageUtil.convertTo(recordList, AiSearchRecordResp.class);
|
||||||
|
list.forEach(item -> {
|
||||||
|
if(null == item.getSearchEngineResp()){
|
||||||
|
SearchEngineResp searchEngineResp = new SearchEngineResp();
|
||||||
|
searchEngineResp.setItems(new ArrayList<>());
|
||||||
|
item.setSearchEngineResp(searchEngineResp);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
result.setRecords(list);
|
||||||
|
result.setMinId(minId);
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean softDelete(String uuid) {
|
||||||
|
if (ThreadContext.getCurrentUser().getIsAdmin()) {
|
||||||
|
return ChainWrappers.lambdaUpdateChain(baseMapper)
|
||||||
|
.eq(AiSearchRecord::getUuid, uuid)
|
||||||
|
.set(AiSearchRecord::getIsDeleted, true)
|
||||||
|
.update();
|
||||||
|
}
|
||||||
|
AiSearchRecord exist = ChainWrappers.lambdaQueryChain(baseMapper)
|
||||||
|
.eq(AiSearchRecord::getUuid, uuid)
|
||||||
|
.eq(AiSearchRecord::getUserId, ThreadContext.getCurrentUserId())
|
||||||
|
.one();
|
||||||
|
if (null == exist) {
|
||||||
|
throw new BaseException(A_DATA_NOT_FOUND);
|
||||||
|
}
|
||||||
|
return ChainWrappers.lambdaUpdateChain(baseMapper)
|
||||||
|
.eq(AiSearchRecord::getId, exist.getId())
|
||||||
|
.set(AiSearchRecord::getIsDeleted, true)
|
||||||
|
.update();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -8,7 +8,6 @@ import com.moyz.adi.common.dto.AskReq;
|
||||||
import com.moyz.adi.common.entity.Conversation;
|
import com.moyz.adi.common.entity.Conversation;
|
||||||
import com.moyz.adi.common.entity.ConversationMessage;
|
import com.moyz.adi.common.entity.ConversationMessage;
|
||||||
import com.moyz.adi.common.entity.User;
|
import com.moyz.adi.common.entity.User;
|
||||||
import com.moyz.adi.common.entity.UserDayCost;
|
|
||||||
import com.moyz.adi.common.enums.ChatMessageRoleEnum;
|
import com.moyz.adi.common.enums.ChatMessageRoleEnum;
|
||||||
import com.moyz.adi.common.enums.ErrorEnum;
|
import com.moyz.adi.common.enums.ErrorEnum;
|
||||||
import com.moyz.adi.common.exception.BaseException;
|
import com.moyz.adi.common.exception.BaseException;
|
||||||
|
@ -16,8 +15,6 @@ import com.moyz.adi.common.helper.QuotaHelper;
|
||||||
import com.moyz.adi.common.helper.SSEEmitterHelper;
|
import com.moyz.adi.common.helper.SSEEmitterHelper;
|
||||||
import com.moyz.adi.common.mapper.ConversationMessageMapper;
|
import com.moyz.adi.common.mapper.ConversationMessageMapper;
|
||||||
import com.moyz.adi.common.util.LocalCache;
|
import com.moyz.adi.common.util.LocalCache;
|
||||||
import com.moyz.adi.common.util.LocalDateTimeUtil;
|
|
||||||
import com.moyz.adi.common.util.UserUtil;
|
|
||||||
import com.moyz.adi.common.vo.AnswerMeta;
|
import com.moyz.adi.common.vo.AnswerMeta;
|
||||||
import com.moyz.adi.common.vo.PromptMeta;
|
import com.moyz.adi.common.vo.PromptMeta;
|
||||||
import com.moyz.adi.common.vo.SseAskParams;
|
import com.moyz.adi.common.vo.SseAskParams;
|
||||||
|
@ -66,11 +63,16 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
||||||
|
|
||||||
public SseEmitter sseAsk(AskReq askReq) {
|
public SseEmitter sseAsk(AskReq askReq) {
|
||||||
SseEmitter sseEmitter = new SseEmitter();
|
SseEmitter sseEmitter = new SseEmitter();
|
||||||
|
User user = ThreadContext.getCurrentUser();
|
||||||
|
if (!sseEmitterHelper.checkOrComplete(user, sseEmitter)) {
|
||||||
|
return sseEmitter;
|
||||||
|
}
|
||||||
|
sseEmitterHelper.startSse(user, sseEmitter);
|
||||||
_this.asyncCheckAndPushToClient(sseEmitter, ThreadContext.getCurrentUser(), askReq);
|
_this.asyncCheckAndPushToClient(sseEmitter, ThreadContext.getCurrentUser(), askReq);
|
||||||
return sseEmitter;
|
return sseEmitter;
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean check(SseEmitter sseEmitter, User user, AskReq askReq) {
|
private boolean checkConversation(SseEmitter sseEmitter, User user, AskReq askReq) {
|
||||||
try {
|
try {
|
||||||
|
|
||||||
//check 1: the conversation has been deleted
|
//check 1: the conversation has been deleted
|
||||||
|
@ -79,7 +81,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
||||||
.eq(Conversation::getIsDeleted, true)
|
.eq(Conversation::getIsDeleted, true)
|
||||||
.one();
|
.one();
|
||||||
if (null != delConv) {
|
if (null != delConv) {
|
||||||
sseEmitterHelper.sendErrorMsg(sseEmitter, "该对话已经删除");
|
sseEmitterHelper.sendErrorAndComplete(user.getId(), sseEmitter, "该对话已经删除");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,14 +92,14 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
||||||
.count();
|
.count();
|
||||||
long convsMax = Integer.parseInt(LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.CONVERSATION_MAX_NUM));
|
long convsMax = Integer.parseInt(LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.CONVERSATION_MAX_NUM));
|
||||||
if (convsCount >= convsMax) {
|
if (convsCount >= convsMax) {
|
||||||
sseEmitterHelper.sendErrorMsg(sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax);
|
sseEmitterHelper.sendErrorAndComplete(user.getId(), sseEmitter, "对话数量已经达到上限,当前对话上限为:" + convsMax);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
//check 3: current user's quota
|
//check 3: current user's quota
|
||||||
ErrorEnum errorMsg = quotaHelper.checkTextQuota(user);
|
ErrorEnum errorMsg = quotaHelper.checkTextQuota(user);
|
||||||
if (null != errorMsg) {
|
if (null != errorMsg) {
|
||||||
sseEmitterHelper.sendErrorMsg(sseEmitter, errorMsg.getInfo());
|
sseEmitterHelper.sendErrorAndComplete(user.getId(), sseEmitter, errorMsg.getInfo());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
@ -112,7 +114,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
||||||
public void asyncCheckAndPushToClient(SseEmitter sseEmitter, User user, AskReq askReq) {
|
public void asyncCheckAndPushToClient(SseEmitter sseEmitter, User user, AskReq askReq) {
|
||||||
log.info("asyncCheckAndPushToClient,userId:{}", user.getId());
|
log.info("asyncCheckAndPushToClient,userId:{}", user.getId());
|
||||||
//check business rules
|
//check business rules
|
||||||
if (!check(sseEmitter, user, askReq)) {
|
if (!checkConversation(sseEmitter, user, askReq)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,7 +163,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sseEmitterHelper.process(user, sseAskParams, (response, questionMeta, answerMeta) -> {
|
sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, questionMeta, answerMeta) -> {
|
||||||
_this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta);
|
_this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ import static com.moyz.adi.common.enums.ErrorEnum.B_LLM_SECRET_KEY_NOT_SET;
|
||||||
public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> {
|
public class DashScopeLLMService extends AbstractLLMService<DashScopeSetting> {
|
||||||
|
|
||||||
public DashScopeLLMService(String modelName) {
|
public DashScopeLLMService(String modelName) {
|
||||||
super(modelName, AdiConstant.SysConfigKey.DASHSCOPE_SETTING, DashScopeSetting.class, null);
|
super(modelName, AdiConstant.SysConfigKey.DASHSCOPE_SETTING, DashScopeSetting.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -3,6 +3,8 @@ package com.moyz.adi.common.service;
|
||||||
import com.moyz.adi.common.cosntant.AdiConstant;
|
import com.moyz.adi.common.cosntant.AdiConstant;
|
||||||
import com.moyz.adi.common.helper.ImageModelContext;
|
import com.moyz.adi.common.helper.ImageModelContext;
|
||||||
import com.moyz.adi.common.helper.LLMContext;
|
import com.moyz.adi.common.helper.LLMContext;
|
||||||
|
import com.moyz.adi.common.searchengine.GoogleSearchEngine;
|
||||||
|
import com.moyz.adi.common.searchengine.SearchEngineContext;
|
||||||
import dev.langchain4j.model.openai.OpenAiModelName;
|
import dev.langchain4j.model.openai.OpenAiModelName;
|
||||||
import jakarta.annotation.PostConstruct;
|
import jakarta.annotation.PostConstruct;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
|
@ -43,16 +45,16 @@ public class Initializer {
|
||||||
|
|
||||||
//openai
|
//openai
|
||||||
String[] openaiModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OPENAI_SETTING);
|
String[] openaiModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OPENAI_SETTING);
|
||||||
if(openaiModels.length == 0){
|
if (openaiModels.length == 0) {
|
||||||
log.warn("openai service is disabled");
|
log.warn("openai service is disabled");
|
||||||
}
|
}
|
||||||
for (String model : openaiModels) {
|
for (String model : openaiModels) {
|
||||||
LLMContext.addLLMService(model, new OpenAiLLMService(model, proxy));
|
LLMContext.addLLMService(model, new OpenAiLLMService(model).setProxy(proxy));
|
||||||
}
|
}
|
||||||
|
|
||||||
//dashscope
|
//dashscope
|
||||||
String[] dashscopeModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.DASHSCOPE_SETTING);
|
String[] dashscopeModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.DASHSCOPE_SETTING);
|
||||||
if(dashscopeModels.length == 0){
|
if (dashscopeModels.length == 0) {
|
||||||
log.warn("dashscope service is disabled");
|
log.warn("dashscope service is disabled");
|
||||||
}
|
}
|
||||||
for (String model : dashscopeModels) {
|
for (String model : dashscopeModels) {
|
||||||
|
@ -61,7 +63,7 @@ public class Initializer {
|
||||||
|
|
||||||
//qianfan
|
//qianfan
|
||||||
String[] qianfanModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.QIANFAN_SETTING);
|
String[] qianfanModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.QIANFAN_SETTING);
|
||||||
if(qianfanModels.length == 0){
|
if (qianfanModels.length == 0) {
|
||||||
log.warn("qianfan service is disabled");
|
log.warn("qianfan service is disabled");
|
||||||
}
|
}
|
||||||
for (String model : qianfanModels) {
|
for (String model : qianfanModels) {
|
||||||
|
@ -70,14 +72,18 @@ public class Initializer {
|
||||||
|
|
||||||
//ollama
|
//ollama
|
||||||
String[] ollamaModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OLLAMA_SETTING);
|
String[] ollamaModels = LLMContext.getSupportModels(AdiConstant.SysConfigKey.OLLAMA_SETTING);
|
||||||
if(ollamaModels.length == 0){
|
if (ollamaModels.length == 0) {
|
||||||
log.warn("ollama service is disabled");
|
log.warn("ollama service is disabled");
|
||||||
}
|
}
|
||||||
for (String model : ollamaModels) {
|
for (String model : ollamaModels) {
|
||||||
LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model));
|
LLMContext.addLLMService("ollama:" + model, new OllamaLLMService(model));
|
||||||
}
|
}
|
||||||
|
|
||||||
ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2, proxy));
|
ImageModelContext.addImageModelService(OpenAiModelName.DALL_E_2, new OpenAiImageModelService(OpenAiModelName.DALL_E_2).setProxy(proxy));
|
||||||
|
|
||||||
|
|
||||||
|
//search engine
|
||||||
|
SearchEngineContext.addEngine(AdiConstant.SearchEngineName.GOOGLE, new GoogleSearchEngine().setProxy(proxy));
|
||||||
|
|
||||||
|
|
||||||
ragService.init();
|
ragService.init();
|
||||||
|
|
|
@ -5,6 +5,7 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||||
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
|
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
|
||||||
import com.moyz.adi.common.base.ThreadContext;
|
import com.moyz.adi.common.base.ThreadContext;
|
||||||
|
import com.moyz.adi.common.cosntant.AdiConstant;
|
||||||
import com.moyz.adi.common.dto.KbItemEditReq;
|
import com.moyz.adi.common.dto.KbItemEditReq;
|
||||||
import com.moyz.adi.common.entity.KnowledgeBase;
|
import com.moyz.adi.common.entity.KnowledgeBase;
|
||||||
import com.moyz.adi.common.entity.KnowledgeBaseItem;
|
import com.moyz.adi.common.entity.KnowledgeBaseItem;
|
||||||
|
@ -24,6 +25,7 @@ import org.springframework.transaction.annotation.Transactional;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
|
||||||
|
import static com.moyz.adi.common.cosntant.AdiConstant.RAG_TYPE_KB;
|
||||||
import static com.moyz.adi.common.enums.ErrorEnum.*;
|
import static com.moyz.adi.common.enums.ErrorEnum.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@ -112,8 +114,8 @@ public class KnowledgeBaseItemService extends ServiceImpl<KnowledgeBaseItemMappe
|
||||||
knowledgeBaseEmbeddingService.deleteByItemUuid(kbItem.getUuid());
|
knowledgeBaseEmbeddingService.deleteByItemUuid(kbItem.getUuid());
|
||||||
|
|
||||||
Metadata metadata = new Metadata();
|
Metadata metadata = new Metadata();
|
||||||
metadata.add("kb_uuid", kbItem.getKbUuid());
|
metadata.add(AdiConstant.EmbeddingMetadataKey.KB_UUID, kbItem.getKbUuid());
|
||||||
metadata.add("kb_item_uuid", kbItem.getUuid());
|
metadata.add(AdiConstant.EmbeddingMetadataKey.KB_ITEM_UUID, kbItem.getUuid());
|
||||||
Document document = new Document(kbItem.getRemark(), metadata);
|
Document document = new Document(kbItem.getRemark(), metadata);
|
||||||
ragService.ingest(document);
|
ragService.ingest(document);
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRec
|
||||||
return baseMapper.selectOne(wrapper);
|
return baseMapper.selectOne(wrapper);
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean softDelele(String uuid) {
|
public boolean softDelete(String uuid) {
|
||||||
if (ThreadContext.getCurrentUser().getIsAdmin()) {
|
if (ThreadContext.getCurrentUser().getIsAdmin()) {
|
||||||
return ChainWrappers.lambdaUpdateChain(baseMapper)
|
return ChainWrappers.lambdaUpdateChain(baseMapper)
|
||||||
.eq(KnowledgeBaseQaRecord::getUuid, uuid)
|
.eq(KnowledgeBaseQaRecord::getUuid, uuid)
|
||||||
|
|
|
@ -4,7 +4,9 @@ import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||||
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
|
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
|
||||||
|
import com.google.common.collect.ImmutableMap;
|
||||||
import com.moyz.adi.common.base.ThreadContext;
|
import com.moyz.adi.common.base.ThreadContext;
|
||||||
|
import com.moyz.adi.common.cosntant.AdiConstant;
|
||||||
import com.moyz.adi.common.cosntant.RedisKeyConstant;
|
import com.moyz.adi.common.cosntant.RedisKeyConstant;
|
||||||
import com.moyz.adi.common.dto.KbEditReq;
|
import com.moyz.adi.common.dto.KbEditReq;
|
||||||
import com.moyz.adi.common.dto.QAReq;
|
import com.moyz.adi.common.dto.QAReq;
|
||||||
|
@ -39,6 +41,7 @@ import java.time.LocalDateTime;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static com.moyz.adi.common.cosntant.AdiConstant.POI_DOC_TYPES;
|
import static com.moyz.adi.common.cosntant.AdiConstant.POI_DOC_TYPES;
|
||||||
|
import static com.moyz.adi.common.cosntant.AdiConstant.SysConfigKey.QUOTA_BY_QA_ASK_DAILY;
|
||||||
import static com.moyz.adi.common.enums.ErrorEnum.*;
|
import static com.moyz.adi.common.enums.ErrorEnum.*;
|
||||||
import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.loadDocument;
|
import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.loadDocument;
|
||||||
|
|
||||||
|
@ -161,9 +164,8 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
||||||
//向量化
|
//向量化
|
||||||
Document docWithoutPath = new Document(document.text());
|
Document docWithoutPath = new Document(document.text());
|
||||||
docWithoutPath.metadata()
|
docWithoutPath.metadata()
|
||||||
.add("kb_uuid", knowledgeBase.getUuid())
|
.add(AdiConstant.EmbeddingMetadataKey.KB_UUID, knowledgeBase.getUuid())
|
||||||
.add("kb_item_uuid", knowledgeBaseItem.getUuid());
|
.add(AdiConstant.EmbeddingMetadataKey.KB_ITEM_UUID, knowledgeBaseItem.getUuid());
|
||||||
|
|
||||||
ragService.ingest(docWithoutPath);
|
ragService.ingest(docWithoutPath);
|
||||||
|
|
||||||
knowledgeBaseItemService
|
knowledgeBaseItemService
|
||||||
|
@ -205,7 +207,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
||||||
LambdaQueryWrapper<KnowledgeBase> wrapper = new LambdaQueryWrapper();
|
LambdaQueryWrapper<KnowledgeBase> wrapper = new LambdaQueryWrapper();
|
||||||
wrapper.eq(KnowledgeBase::getIsPublic, true);
|
wrapper.eq(KnowledgeBase::getIsPublic, true);
|
||||||
wrapper.eq(KnowledgeBase::getIsDeleted, false);
|
wrapper.eq(KnowledgeBase::getIsDeleted, false);
|
||||||
if(StringUtils.isNotBlank(keyword)){
|
if (StringUtils.isNotBlank(keyword)) {
|
||||||
wrapper.like(KnowledgeBase::getTitle, keyword);
|
wrapper.like(KnowledgeBase::getTitle, keyword);
|
||||||
}
|
}
|
||||||
wrapper.orderByDesc(KnowledgeBase::getStarCount, KnowledgeBase::getUpdateTime);
|
wrapper.orderByDesc(KnowledgeBase::getStarCount, KnowledgeBase::getUpdateTime);
|
||||||
|
@ -223,7 +225,8 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
||||||
public KnowledgeBaseQaRecord ask(String kbUuid, String question, String modelName) {
|
public KnowledgeBaseQaRecord ask(String kbUuid, String question, String modelName) {
|
||||||
checkRequestTimesOrThrow();
|
checkRequestTimesOrThrow();
|
||||||
KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
|
KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
|
||||||
Pair<String, Response<AiMessage>> responsePair = ragService.retrieveAndAsk(kbUuid, question, modelName);
|
Map<String, String> metadataCond = ImmutableMap.of(AdiConstant.EmbeddingMetadataKey.KB_UUID, kbUuid);
|
||||||
|
Pair<String, Response<AiMessage>> responsePair = ragService.retrieveAndAsk(metadataCond, question, modelName);
|
||||||
|
|
||||||
Response<AiMessage> ar = responsePair.getRight();
|
Response<AiMessage> ar = responsePair.getRight();
|
||||||
int inputTokenCount = ar.tokenUsage().inputTokenCount();
|
int inputTokenCount = ar.tokenUsage().inputTokenCount();
|
||||||
|
@ -235,7 +238,12 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
||||||
public SseEmitter sseAsk(String kbUuid, QAReq req) {
|
public SseEmitter sseAsk(String kbUuid, QAReq req) {
|
||||||
checkRequestTimesOrThrow();
|
checkRequestTimesOrThrow();
|
||||||
SseEmitter sseEmitter = new SseEmitter();
|
SseEmitter sseEmitter = new SseEmitter();
|
||||||
_this.retrieveAndPushToLLM(ThreadContext.getCurrentUser(), sseEmitter, kbUuid, req);
|
User user = ThreadContext.getCurrentUser();
|
||||||
|
if (!sseEmitterHelper.checkOrComplete(user, sseEmitter)) {
|
||||||
|
return sseEmitter;
|
||||||
|
}
|
||||||
|
sseEmitterHelper.startSse(user, sseEmitter);
|
||||||
|
_this.retrieveAndPushToLLM(user, sseEmitter, kbUuid, req);
|
||||||
return sseEmitter;
|
return sseEmitter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -253,7 +261,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
||||||
private void checkRequestTimesOrThrow() {
|
private void checkRequestTimesOrThrow() {
|
||||||
String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd"));
|
String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd"));
|
||||||
String askTimes = stringRedisTemplate.opsForValue().get(key);
|
String askTimes = stringRedisTemplate.opsForValue().get(key);
|
||||||
String askQuota = SysConfigService.getByKey("quota_by_qa_ask_daily");
|
String askQuota = SysConfigService.getByKey(QUOTA_BY_QA_ASK_DAILY);
|
||||||
if (null != askTimes && null != askQuota && Integer.parseInt(askTimes) >= Integer.parseInt(askQuota)) {
|
if (null != askTimes && null != askQuota && Integer.parseInt(askTimes) >= Integer.parseInt(askQuota)) {
|
||||||
throw new BaseException(A_QA_ASK_LIMIT);
|
throw new BaseException(A_QA_ASK_LIMIT);
|
||||||
}
|
}
|
||||||
|
@ -265,10 +273,10 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
||||||
public void retrieveAndPushToLLM(User user, SseEmitter sseEmitter, String kbUuid, QAReq req) {
|
public void retrieveAndPushToLLM(User user, SseEmitter sseEmitter, String kbUuid, QAReq req) {
|
||||||
log.info("retrieveAndPushToLLM,kbUuid:{},userId:{}", kbUuid, user.getId());
|
log.info("retrieveAndPushToLLM,kbUuid:{},userId:{}", kbUuid, user.getId());
|
||||||
KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
|
KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
|
||||||
|
Map<String, String> metadataCond = ImmutableMap.of(AdiConstant.EmbeddingMetadataKey.KB_UUID, kbUuid);
|
||||||
Prompt prompt = ragService.retrieveAndCreatePrompt(kbUuid, req.getQuestion());
|
Prompt prompt = ragService.retrieveAndCreatePrompt(metadataCond, req.getQuestion());
|
||||||
if(null == prompt){
|
if (null == prompt) {
|
||||||
sseEmitterHelper.sendAndComplete(sseEmitter, B_KNOWLEDGE_BASE_NO_ANSWER.getInfo());
|
sseEmitterHelper.sendAndComplete(user.getId(), sseEmitter, B_NO_ANSWER.getInfo());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
String promptText = prompt.text();
|
String promptText = prompt.text();
|
||||||
|
@ -277,7 +285,7 @@ public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, Knowl
|
||||||
sseAskParams.setSseEmitter(sseEmitter);
|
sseAskParams.setSseEmitter(sseEmitter);
|
||||||
sseAskParams.setUserMessage(promptText);
|
sseAskParams.setUserMessage(promptText);
|
||||||
sseAskParams.setModelName(req.getModelName());
|
sseAskParams.setModelName(req.getModelName());
|
||||||
sseEmitterHelper.process(user, sseAskParams, (response, promptMeta, answerMeta) -> {
|
sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, promptMeta, answerMeta) -> {
|
||||||
knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), promptText, promptMeta.getTokens(), response, answerMeta.getTokens());
|
knowledgeBaseQaRecordService.createNewRecord(user, knowledgeBase, req.getQuestion(), promptText, promptMeta.getTokens(), response, answerMeta.getTokens());
|
||||||
userDayCostService.appendCostToUser(user, promptMeta.getTokens() + answerMeta.getTokens());
|
userDayCostService.appendCostToUser(user, promptMeta.getTokens() + answerMeta.getTokens());
|
||||||
});
|
});
|
||||||
|
|
|
@ -7,14 +7,13 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||||
import dev.langchain4j.model.ollama.OllamaChatModel;
|
import dev.langchain4j.model.ollama.OllamaChatModel;
|
||||||
import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
|
import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
import static com.moyz.adi.common.cosntant.AdiConstant.SysConfigKey.OLLAMA_SETTING;
|
import static com.moyz.adi.common.cosntant.AdiConstant.SysConfigKey.OLLAMA_SETTING;
|
||||||
|
|
||||||
public class OllamaLLMService extends AbstractLLMService<OllamaSetting> {
|
public class OllamaLLMService extends AbstractLLMService<OllamaSetting> {
|
||||||
|
|
||||||
public OllamaLLMService(String modelName) {
|
public OllamaLLMService(String modelName) {
|
||||||
super(modelName, OLLAMA_SETTING, OllamaSetting.class, null);
|
super(modelName, OLLAMA_SETTING, OllamaSetting.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -45,8 +45,8 @@ public class OpenAiImageModelService extends AbstractImageModelService<OpenAiSet
|
||||||
@Resource
|
@Resource
|
||||||
private ObjectMapper objectMapper;
|
private ObjectMapper objectMapper;
|
||||||
|
|
||||||
public OpenAiImageModelService(String modelName, Proxy proxy) {
|
public OpenAiImageModelService(String modelName) {
|
||||||
super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class, proxy);
|
super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -17,7 +17,6 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.logging.log4j.util.Strings;
|
import org.apache.logging.log4j.util.Strings;
|
||||||
|
|
||||||
import java.net.Proxy;
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.time.temporal.ChronoUnit;
|
import java.time.temporal.ChronoUnit;
|
||||||
|
|
||||||
|
@ -28,8 +27,8 @@ import java.time.temporal.ChronoUnit;
|
||||||
@Accessors(chain = true)
|
@Accessors(chain = true)
|
||||||
public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
|
public class OpenAiLLMService extends AbstractLLMService<OpenAiSetting> {
|
||||||
|
|
||||||
public OpenAiLLMService(String modelName, Proxy proxy) {
|
public OpenAiLLMService(String modelName) {
|
||||||
super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class, proxy);
|
super(modelName, AdiConstant.SysConfigKey.OPENAI_SETTING, OpenAiSetting.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -19,7 +19,7 @@ import org.apache.commons.lang3.StringUtils;
|
||||||
public class QianFanLLMService extends AbstractLLMService<QianFanSetting> {
|
public class QianFanLLMService extends AbstractLLMService<QianFanSetting> {
|
||||||
|
|
||||||
public QianFanLLMService(String modelName) {
|
public QianFanLLMService(String modelName) {
|
||||||
super(modelName, AdiConstant.SysConfigKey.QIANFAN_SETTING, QianFanSetting.class, null);
|
super(modelName, AdiConstant.SysConfigKey.QIANFAN_SETTING, QianFanSetting.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -1,10 +1,7 @@
|
||||||
package com.moyz.adi.common.service;
|
package com.moyz.adi.common.service;
|
||||||
|
|
||||||
import com.moyz.adi.common.helper.LLMContext;
|
import com.moyz.adi.common.helper.LLMContext;
|
||||||
import com.moyz.adi.common.interfaces.TriConsumer;
|
|
||||||
import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore;
|
import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore;
|
||||||
import com.moyz.adi.common.vo.AnswerMeta;
|
|
||||||
import com.moyz.adi.common.vo.PromptMeta;
|
|
||||||
import dev.langchain4j.data.document.Document;
|
import dev.langchain4j.data.document.Document;
|
||||||
import dev.langchain4j.data.document.DocumentSplitter;
|
import dev.langchain4j.data.document.DocumentSplitter;
|
||||||
import dev.langchain4j.data.document.splitter.DocumentSplitters;
|
import dev.langchain4j.data.document.splitter.DocumentSplitters;
|
||||||
|
@ -14,7 +11,6 @@ import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
|
||||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
|
@ -24,35 +20,37 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.apache.commons.lang3.tuple.Triple;
|
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.regex.Matcher;
|
import java.util.regex.Matcher;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
|
import static com.moyz.adi.common.cosntant.AdiConstant.PROMPT_TEMPLATE;
|
||||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||||
import static java.util.stream.Collectors.joining;
|
import static java.util.stream.Collectors.joining;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Service
|
|
||||||
public class RAGService {
|
public class RAGService {
|
||||||
@Value("${spring.datasource.url}")
|
|
||||||
private String dataBaseUrl;
|
private String dataBaseUrl;
|
||||||
|
|
||||||
@Value("${spring.datasource.username}")
|
|
||||||
private String dataBaseUserName;
|
private String dataBaseUserName;
|
||||||
|
|
||||||
@Value("${spring.datasource.password}")
|
|
||||||
private String dataBasePassword;
|
private String dataBasePassword;
|
||||||
private static final PromptTemplate promptTemplate = PromptTemplate.from("尽可能准确地回答下面的问题: {{question}}\n\n根据以下知识库的内容:\n{{information}}");
|
|
||||||
|
private String tableName;
|
||||||
|
|
||||||
private EmbeddingModel embeddingModel;
|
private EmbeddingModel embeddingModel;
|
||||||
|
|
||||||
private EmbeddingStore<TextSegment> embeddingStore;
|
private EmbeddingStore<TextSegment> embeddingStore;
|
||||||
|
|
||||||
|
public RAGService(String tableName, String dataBaseUrl, String dataBaseUserName, String dataBasePassword) {
|
||||||
|
this.tableName = tableName;
|
||||||
|
this.dataBasePassword = dataBasePassword;
|
||||||
|
this.dataBaseUserName = dataBaseUserName;
|
||||||
|
this.dataBaseUrl = dataBaseUrl;
|
||||||
|
}
|
||||||
|
|
||||||
public void init() {
|
public void init() {
|
||||||
log.info("initEmbeddingModel");
|
log.info("initEmbeddingModel");
|
||||||
embeddingModel = new AllMiniLmL6V2EmbeddingModel();
|
embeddingModel = new AllMiniLmL6V2EmbeddingModel();
|
||||||
|
@ -88,7 +86,7 @@ public class RAGService {
|
||||||
.dimension(384)
|
.dimension(384)
|
||||||
.createTable(true)
|
.createTable(true)
|
||||||
.dropTableFirst(false)
|
.dropTableFirst(false)
|
||||||
.table("adi_knowledge_base_embedding")
|
.table(tableName)
|
||||||
.build();
|
.build();
|
||||||
return embeddingStore;
|
return embeddingStore;
|
||||||
}
|
}
|
||||||
|
@ -112,7 +110,14 @@ public class RAGService {
|
||||||
getEmbeddingStoreIngestor().ingest(document);
|
getEmbeddingStoreIngestor().ingest(document);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Prompt retrieveAndCreatePrompt(String kbUuid, String question) {
|
/**
|
||||||
|
* Retrieve documents and create prompt
|
||||||
|
*
|
||||||
|
* @param metadataCond Query condition
|
||||||
|
* @param question User's question
|
||||||
|
* @return Document in the vector db
|
||||||
|
*/
|
||||||
|
public Prompt retrieveAndCreatePrompt(Map<String, String> metadataCond, String question) {
|
||||||
// Embed the question
|
// Embed the question
|
||||||
Embedding questionEmbedding = embeddingModel.embed(question).content();
|
Embedding questionEmbedding = embeddingModel.embed(question).content();
|
||||||
|
|
||||||
|
@ -120,7 +125,7 @@ public class RAGService {
|
||||||
// You can play with parameters below to find a sweet spot for your specific use case
|
// You can play with parameters below to find a sweet spot for your specific use case
|
||||||
int maxResults = 3;
|
int maxResults = 3;
|
||||||
double minScore = 0.6;
|
double minScore = 0.6;
|
||||||
List<EmbeddingMatch<TextSegment>> relevantEmbeddings = ((AdiPgVectorEmbeddingStore) embeddingStore).findRelevantByKbUuid(kbUuid, questionEmbedding, maxResults, minScore);
|
List<EmbeddingMatch<TextSegment>> relevantEmbeddings = ((AdiPgVectorEmbeddingStore) embeddingStore).findRelevantByMetadata(metadataCond, questionEmbedding, maxResults, minScore);
|
||||||
|
|
||||||
// Create a prompt for the model that includes question and relevant embeddings
|
// Create a prompt for the model that includes question and relevant embeddings
|
||||||
String information = relevantEmbeddings.stream()
|
String information = relevantEmbeddings.stream()
|
||||||
|
@ -130,23 +135,28 @@ public class RAGService {
|
||||||
if (StringUtils.isBlank(information)) {
|
if (StringUtils.isBlank(information)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
return promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information)));
|
return PROMPT_TEMPLATE.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 召回并提问
|
* 召回并提问
|
||||||
*
|
*
|
||||||
* @param kbUuid 知识库uuid
|
* @param metadataCond metadata condition
|
||||||
* @param question 用户的问题
|
* @param question user's question
|
||||||
* @param modelName LLM model name
|
* @param modelName LLM model name
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Pair<String, Response<AiMessage>> retrieveAndAsk(String kbUuid, String question, String modelName) {
|
public Pair<String, Response<AiMessage>> retrieveAndAsk(Map<String, String> metadataCond, String question, String modelName) {
|
||||||
Prompt prompt = retrieveAndCreatePrompt(kbUuid, question);
|
|
||||||
|
Prompt prompt = retrieveAndCreatePrompt(metadataCond, question);
|
||||||
if (null == prompt) {
|
if (null == prompt) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
Response<AiMessage> response = new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage());
|
Response<AiMessage> response = new LLMContext(modelName).getLLMService().chat(prompt.toUserMessage());
|
||||||
return new ImmutablePair<>(prompt.text(), response);
|
return new ImmutablePair<>(prompt.text(), response);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static final String parsePromptTemplate(String question, String information) {
|
||||||
|
return PROMPT_TEMPLATE.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information))).text();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,250 @@
|
||||||
|
package com.moyz.adi.common.service;
|
||||||
|
|
||||||
|
import com.google.common.collect.ImmutableMap;
|
||||||
|
import com.moyz.adi.common.base.ThreadContext;
|
||||||
|
import com.moyz.adi.common.cosntant.AdiConstant;
|
||||||
|
import com.moyz.adi.common.dto.SearchEngineResp;
|
||||||
|
import com.moyz.adi.common.dto.SearchResult;
|
||||||
|
import com.moyz.adi.common.dto.SearchResultItem;
|
||||||
|
import com.moyz.adi.common.entity.AiSearchRecord;
|
||||||
|
import com.moyz.adi.common.entity.User;
|
||||||
|
import com.moyz.adi.common.helper.SSEEmitterHelper;
|
||||||
|
import com.moyz.adi.common.searchengine.SearchEngineContext;
|
||||||
|
import com.moyz.adi.common.vo.SseAskParams;
|
||||||
|
import dev.langchain4j.data.document.Document;
|
||||||
|
import dev.langchain4j.data.document.Metadata;
|
||||||
|
import dev.langchain4j.model.input.Prompt;
|
||||||
|
import jakarta.annotation.Resource;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.jsoup.Jsoup;
|
||||||
|
import org.jsoup.safety.Cleaner;
|
||||||
|
import org.jsoup.safety.Safelist;
|
||||||
|
import org.springframework.context.annotation.Lazy;
|
||||||
|
import org.springframework.core.task.AsyncTaskExecutor;
|
||||||
|
import org.springframework.scheduling.annotation.Async;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
|
||||||
|
import static com.moyz.adi.common.enums.ErrorEnum.B_NO_ANSWER;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RAG search
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
public class SearchService {
|
||||||
|
|
||||||
|
@Lazy
|
||||||
|
@Resource
|
||||||
|
private SearchService _this;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private RAGService searchRagService;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private SSEEmitterHelper sseEmitterHelper;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private AiSearchRecordService aiSearchRecordService;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private AsyncTaskExecutor mainExecutor;
|
||||||
|
|
||||||
|
public SseEmitter search(boolean isBriefSearch, String searchText, String engineName, String modelName) {
|
||||||
|
User user = ThreadContext.getCurrentUser();
|
||||||
|
SseEmitter sseEmitter = new SseEmitter();
|
||||||
|
if (!sseEmitterHelper.checkOrComplete(user, sseEmitter)) {
|
||||||
|
return sseEmitter;
|
||||||
|
}
|
||||||
|
sseEmitterHelper.startSse(user, sseEmitter);
|
||||||
|
_this.asyncSearch(user, sseEmitter, isBriefSearch, searchText, engineName, modelName);
|
||||||
|
return sseEmitter;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Async
|
||||||
|
public void asyncSearch(User user, SseEmitter sseEmitter, boolean isBriefSearch, String searchText, String engineName, String modelName) {
|
||||||
|
SearchResult searchResult = new SearchEngineContext(engineName).getEngine().search(searchText);
|
||||||
|
if (StringUtils.isNotBlank(searchResult.getErrorMessage())) {
|
||||||
|
sseEmitterHelper.sendAndComplete(user.getId(), sseEmitter, searchResult.getErrorMessage());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (CollectionUtils.isEmpty(searchResult.getItems())) {
|
||||||
|
sseEmitterHelper.sendAndComplete(user.getId(), sseEmitter, B_NO_ANSWER.getInfo());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
boolean sendFail = false;
|
||||||
|
try {
|
||||||
|
sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.AI_SEARCH_SOURCE_LINKS).data(searchResult.getItems()));
|
||||||
|
} catch (IOException e) {
|
||||||
|
sendFail = true;
|
||||||
|
log.error("asyncSearch error", e);
|
||||||
|
sseEmitterHelper.sendErrorAndComplete(user.getId(), sseEmitter, e.getMessage());
|
||||||
|
}
|
||||||
|
if (sendFail) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (isBriefSearch) {
|
||||||
|
briefSearch(user, searchText, modelName, searchResult.getItems(), sseEmitter);
|
||||||
|
} else {
|
||||||
|
detailSearch(user, searchText, engineName, modelName, searchResult.getItems(), sseEmitter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 1.Search by search engine
|
||||||
|
* 2.Create prompt by search response
|
||||||
|
* 3.Send prompt to llm
|
||||||
|
*
|
||||||
|
* @param user
|
||||||
|
* @param searchText
|
||||||
|
* @param modelName
|
||||||
|
* @param resultItems
|
||||||
|
* @param sseEmitter
|
||||||
|
*/
|
||||||
|
public void briefSearch(User user, String searchText, String modelName, List<SearchResultItem> resultItems, SseEmitter sseEmitter) {
|
||||||
|
log.info("briefSearch,searchText:{}", searchText);
|
||||||
|
StringBuilder builder = new StringBuilder();
|
||||||
|
for (SearchResultItem item : resultItems) {
|
||||||
|
builder.append(item.getSnippet()).append("\n\n");
|
||||||
|
}
|
||||||
|
String ragQuestion = builder.toString();
|
||||||
|
String prompt = RAGService.parsePromptTemplate(searchText, ragQuestion);
|
||||||
|
|
||||||
|
SearchEngineResp resp = new SearchEngineResp().setItems(resultItems);
|
||||||
|
|
||||||
|
SseAskParams sseAskParams = new SseAskParams();
|
||||||
|
sseAskParams.setSystemMessage(StringUtils.EMPTY);
|
||||||
|
sseAskParams.setSseEmitter(sseEmitter);
|
||||||
|
sseAskParams.setUserMessage(prompt);
|
||||||
|
sseAskParams.setModelName(modelName);
|
||||||
|
sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, promptMeta, answerMeta) -> {
|
||||||
|
AiSearchRecord newRecord = new AiSearchRecord();
|
||||||
|
newRecord.setUuid(UUID.randomUUID().toString().replace("-", ""));
|
||||||
|
newRecord.setQuestion(searchText);
|
||||||
|
newRecord.setSearchEngineResp(resp);
|
||||||
|
newRecord.setPrompt(prompt);
|
||||||
|
newRecord.setPromptTokens(promptMeta.getTokens());
|
||||||
|
newRecord.setAnswer(response);
|
||||||
|
newRecord.setAnswerTokens(answerMeta.getTokens());
|
||||||
|
newRecord.setUserUuid(user.getUuid());
|
||||||
|
newRecord.setUserId(user.getId());
|
||||||
|
aiSearchRecordService.save(newRecord);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 1.Search by search engine
|
||||||
|
* 2.Save the response to pgvector
|
||||||
|
* 3.Retrieve document and create prompt
|
||||||
|
* 4.Send prompt to llm
|
||||||
|
*
|
||||||
|
* @param user
|
||||||
|
* @param searchText
|
||||||
|
* @param engineName
|
||||||
|
* @param modelName
|
||||||
|
* @param resultItems
|
||||||
|
* @param sseEmitter
|
||||||
|
*/
|
||||||
|
public void detailSearch(User user, String searchText, String engineName, String modelName, List<SearchResultItem> resultItems, SseEmitter sseEmitter) {
|
||||||
|
log.info("detailSearch,searchText:{}", searchText);
|
||||||
|
//Save to DB
|
||||||
|
SearchEngineResp resp = new SearchEngineResp().setItems(resultItems);
|
||||||
|
AiSearchRecord newRecord = new AiSearchRecord();
|
||||||
|
String searchUuid = UUID.randomUUID().toString().replace("-", "");
|
||||||
|
newRecord.setUuid(searchUuid);
|
||||||
|
newRecord.setQuestion(searchText);
|
||||||
|
newRecord.setSearchEngineResp(resp);
|
||||||
|
newRecord.setUserId(user.getId());
|
||||||
|
newRecord.setUserUuid(user.getUuid());
|
||||||
|
aiSearchRecordService.save(newRecord);
|
||||||
|
|
||||||
|
CountDownLatch countDownLatch = new CountDownLatch(resultItems.size());
|
||||||
|
for (int i = 0; i < resultItems.size(); i++) {
|
||||||
|
int finalI = i;
|
||||||
|
mainExecutor.execute(() -> {
|
||||||
|
try {
|
||||||
|
SearchResultItem item = resultItems.get(finalI);
|
||||||
|
String content;
|
||||||
|
if (finalI < 2) {
|
||||||
|
content = getContentFromRemote(item);
|
||||||
|
|
||||||
|
//Fill content with html body text
|
||||||
|
item.setContent(content);
|
||||||
|
} else {
|
||||||
|
content = item.getSnippet();
|
||||||
|
}
|
||||||
|
|
||||||
|
//embedding
|
||||||
|
Metadata metadata = new Metadata();
|
||||||
|
metadata.add(AdiConstant.EmbeddingMetadataKey.ENGINE_NAME, engineName);
|
||||||
|
metadata.add(AdiConstant.EmbeddingMetadataKey.SEARCH_UUID, searchUuid);
|
||||||
|
Document document = new Document(content, metadata);
|
||||||
|
searchRagService.ingest(document);
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Detail search error,uuid:{}", searchUuid, e);
|
||||||
|
} finally {
|
||||||
|
countDownLatch.countDown();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
countDownLatch.await();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
log.error("CountDownLatch await error,uuid:{}", searchUuid, e);
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info("Create prompt");
|
||||||
|
Prompt prompt = searchRagService.retrieveAndCreatePrompt(ImmutableMap.of(AdiConstant.EmbeddingMetadataKey.SEARCH_UUID, searchUuid), searchText);
|
||||||
|
|
||||||
|
SseAskParams sseAskParams = new SseAskParams();
|
||||||
|
sseAskParams.setSystemMessage(StringUtils.EMPTY);
|
||||||
|
sseAskParams.setSseEmitter(sseEmitter);
|
||||||
|
sseAskParams.setUserMessage(prompt.text());
|
||||||
|
sseAskParams.setModelName(modelName);
|
||||||
|
|
||||||
|
log.info("Push to model");
|
||||||
|
sseEmitterHelper.processAndPushToModel(user, sseAskParams, (response, promptMeta, answerMeta) -> {
|
||||||
|
|
||||||
|
AiSearchRecord existRecord = aiSearchRecordService.lambdaQuery().eq(AiSearchRecord::getUuid, searchUuid).one();
|
||||||
|
|
||||||
|
AiSearchRecord updateRecord = new AiSearchRecord();
|
||||||
|
updateRecord.setId(existRecord.getId());
|
||||||
|
//Update search engine response content.(with html body text)
|
||||||
|
updateRecord.setSearchEngineResp(new SearchEngineResp().setItems(resultItems));
|
||||||
|
updateRecord.setPrompt(prompt.text());
|
||||||
|
updateRecord.setPromptTokens(promptMeta.getTokens());
|
||||||
|
updateRecord.setAnswer(response);
|
||||||
|
updateRecord.setAnswerTokens(answerMeta.getTokens());
|
||||||
|
aiSearchRecordService.updateById(updateRecord);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getContentFromRemote(SearchResultItem item) {
|
||||||
|
String result = "";
|
||||||
|
try {
|
||||||
|
org.jsoup.nodes.Document doc = Jsoup.connect(item.getLink()).ignoreContentType(true).get();
|
||||||
|
if (doc.getElementsByTag("main").size() > 0) {
|
||||||
|
result = doc.getElementsByTag("main").get(0).html();
|
||||||
|
} else {
|
||||||
|
result = doc.body().html();
|
||||||
|
}
|
||||||
|
if (StringUtils.isBlank(result)) {
|
||||||
|
log.error("Empty content from {}, use snippet instead", item.getLink());
|
||||||
|
return item.getSnippet();
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Failed to load document from {}, use snippet instead", item.getLink(), e);
|
||||||
|
}
|
||||||
|
Cleaner cleaner = new Cleaner(Safelist.none());
|
||||||
|
return cleaner.clean(Jsoup.parse(result)).text();
|
||||||
|
}
|
||||||
|
}
|
|
@ -9,6 +9,7 @@ import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.math.NumberUtils;
|
import org.apache.commons.lang3.math.NumberUtils;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
|
@ -286,16 +287,24 @@ public class AdiPgVectorEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//adi
|
//adi
|
||||||
public List<EmbeddingMatch<TextSegment>> findRelevantByKbUuid(String kbUuid, Embedding referenceEmbedding, int maxResults, double minScore) {
|
public List<EmbeddingMatch<TextSegment>> findRelevantByMetadata(Map<String, String> metadatCondition, Embedding referenceEmbedding, int maxResults, double minScore) {
|
||||||
List<EmbeddingMatch<TextSegment>> result = new ArrayList<>();
|
List<EmbeddingMatch<TextSegment>> result = new ArrayList<>();
|
||||||
try (Connection connection = setupConnection()) {
|
try (Connection connection = setupConnection()) {
|
||||||
String referenceVector = Arrays.toString(referenceEmbedding.vector());
|
String referenceVector = Arrays.toString(referenceEmbedding.vector());
|
||||||
//新增查询条件kb_id
|
//deal with metadata condition
|
||||||
|
StringBuilder whereSql = new StringBuilder();
|
||||||
|
if (null != metadatCondition && !metadatCondition.isEmpty()) {
|
||||||
|
whereSql = new StringBuilder("where");
|
||||||
|
for (String key : metadatCondition.keySet()) {
|
||||||
|
whereSql.append(" metadata->>'").append(key).append("' = '").append(metadatCondition.get(key)).append("' and");
|
||||||
|
}
|
||||||
|
whereSql.replace(whereSql.length() - 3, whereSql.length(), "");
|
||||||
|
}
|
||||||
String query = String.format(
|
String query = String.format(
|
||||||
"WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s where metadata->>'kb_uuid' = '%s') SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;",
|
"WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s " + whereSql + ") SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;",
|
||||||
referenceVector, table, kbUuid, minScore, maxResults);
|
referenceVector, table, minScore, maxResults);
|
||||||
|
log.info(query);
|
||||||
PreparedStatement selectStmt = connection.prepareStatement(query);
|
PreparedStatement selectStmt = connection.prepareStatement(query);
|
||||||
|
|
||||||
ResultSet resultSet = selectStmt.executeQuery();
|
ResultSet resultSet = selectStmt.executeQuery();
|
||||||
|
|
|
@ -10,6 +10,7 @@ import org.apache.commons.collections4.CollectionUtils;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.function.BiConsumer;
|
import java.util.function.BiConsumer;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
import java.util.function.ObjLongConsumer;
|
||||||
import java.util.function.Supplier;
|
import java.util.function.Supplier;
|
||||||
|
|
||||||
public class BizPager {
|
public class BizPager {
|
||||||
|
@ -68,6 +69,21 @@ public class BizPager {
|
||||||
} while (!Thread.currentThread().isInterrupted() && records.size() == AdiConstant.DEFAULT_PAGE_SIZE);
|
} while (!Thread.currentThread().isInterrupted() && records.size() == AdiConstant.DEFAULT_PAGE_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static <T extends BaseEntity> void listByMaxId(Long maxId, LambdaQueryWrapper<T> queryWrapper, IService<T> service, SFunction<T, Long> idSupplier, ObjLongConsumer<List<T>> consumer) {
|
||||||
|
if (maxId > 0) {
|
||||||
|
queryWrapper.lt(idSupplier, maxId);
|
||||||
|
}
|
||||||
|
queryWrapper.orderByDesc(idSupplier);
|
||||||
|
queryWrapper.last("limit " + AdiConstant.DEFAULT_PAGE_SIZE);
|
||||||
|
|
||||||
|
long minId = 0;
|
||||||
|
List<T> records = service.list(queryWrapper);
|
||||||
|
if (CollectionUtils.isNotEmpty(records)) {
|
||||||
|
minId = records.stream().map(idSupplier).reduce(Long::min).get();
|
||||||
|
}
|
||||||
|
consumer.accept(records, minId);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 以Long类型的惟一字段(通常为id)为锚点,按页获取数据
|
* 以Long类型的惟一字段(通常为id)为锚点,按页获取数据
|
||||||
* <br/>不依赖mybatis-plus
|
* <br/>不依赖mybatis-plus
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
package com.moyz.adi.common.util;
|
||||||
|
|
||||||
|
import org.springframework.beans.BeansException;
|
||||||
|
import org.springframework.context.ApplicationContext;
|
||||||
|
import org.springframework.context.ApplicationContextAware;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
@Component
|
||||||
|
public class SpringUtil implements ApplicationContextAware {
|
||||||
|
|
||||||
|
private static ApplicationContext applicationContext;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
|
||||||
|
SpringUtil.applicationContext = applicationContext;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <T> T getBean(String name) {
|
||||||
|
return (T) applicationContext.getBean(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static <T> T getBean(Class<T> clazz) {
|
||||||
|
return applicationContext.getBean(clazz);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,10 @@
|
||||||
|
package com.moyz.adi.common.vo;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class GoogleSetting {
|
||||||
|
private String url;
|
||||||
|
private String key;
|
||||||
|
private String cx;
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
package com.moyz.adi.common.vo;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.moyz.adi.common.interfaces.AbstractSearchEngine;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SearchEngineInfo {
|
||||||
|
private String name;
|
||||||
|
private Boolean enable;
|
||||||
|
@JsonIgnore
|
||||||
|
private AbstractSearchEngine engine;
|
||||||
|
}
|
|
@ -390,6 +390,9 @@ VALUES ('qianfan_setting', '{"api_key":"","secret_key":"","models":[]}');
|
||||||
INSERT INTO adi_sys_config (name, value)
|
INSERT INTO adi_sys_config (name, value)
|
||||||
VALUES ('ollama_setting', '{"base_url":"","models":[]}');
|
VALUES ('ollama_setting', '{"base_url":"","models":[]}');
|
||||||
INSERT INTO adi_sys_config (name, value)
|
INSERT INTO adi_sys_config (name, value)
|
||||||
|
VALUES ('google_setting',
|
||||||
|
'{"url":"https://www.googleapis.com/customsearch/v1","key":"","cx":""}');
|
||||||
|
INSERT INTO adi_sys_config (name, value)
|
||||||
VALUES ('request_text_rate_limit', '{"times":24,"minutes":3}');
|
VALUES ('request_text_rate_limit', '{"times":24,"minutes":3}');
|
||||||
INSERT INTO adi_sys_config (name, value)
|
INSERT INTO adi_sys_config (name, value)
|
||||||
VALUES ('request_image_rate_limit', '{"times":6,"minutes":3}');
|
VALUES ('request_image_rate_limit', '{"times":6,"minutes":3}');
|
||||||
|
@ -544,3 +547,48 @@ create trigger trigger_kb_qa_record_update_time
|
||||||
on adi_knowledge_base_qa_record
|
on adi_knowledge_base_qa_record
|
||||||
for each row
|
for each row
|
||||||
execute procedure update_modified_column();
|
execute procedure update_modified_column();
|
||||||
|
|
||||||
|
-- ai search
|
||||||
|
create table adi_ai_search_record
|
||||||
|
(
|
||||||
|
id bigserial primary key,
|
||||||
|
uuid varchar(32) default ''::character varying not null,
|
||||||
|
question varchar(1000) default ''::character varying not null,
|
||||||
|
search_engine_response jsonb not null,
|
||||||
|
prompt text default ''::character varying not null,
|
||||||
|
prompt_tokens integer DEFAULT 0 NOT NULL,
|
||||||
|
answer text default ''::character varying not null,
|
||||||
|
answer_tokens integer DEFAULT 0 NOT NULL,
|
||||||
|
user_id bigint default '0' NOT NULL,
|
||||||
|
user_uuid varchar(32) default ''::character varying not null,
|
||||||
|
create_time timestamp default CURRENT_TIMESTAMP not null,
|
||||||
|
update_time timestamp default CURRENT_TIMESTAMP not null,
|
||||||
|
is_deleted boolean default false not null
|
||||||
|
);
|
||||||
|
comment on table adi_ai_search_record is 'Search record';
|
||||||
|
|
||||||
|
comment on column adi_ai_search_record.question is 'User original question';
|
||||||
|
|
||||||
|
comment on column adi_ai_search_record.search_engine_response is 'Search engine''s response content';
|
||||||
|
|
||||||
|
comment on column adi_ai_search_record.prompt is 'Prompt of LLM';
|
||||||
|
|
||||||
|
comment on column adi_ai_search_record.prompt_tokens is 'prompt消耗的token数量';
|
||||||
|
|
||||||
|
comment on column adi_ai_search_record.answer is 'LLM response';
|
||||||
|
|
||||||
|
comment on column adi_ai_search_record.answer_tokens is 'LLM响应消耗的token数量';
|
||||||
|
|
||||||
|
comment on column adi_ai_search_record.user_id is 'Id from adi_user';
|
||||||
|
|
||||||
|
comment on column adi_ai_search_record.create_time is '创建时间';
|
||||||
|
|
||||||
|
comment on column adi_ai_search_record.update_time is '更新时间';
|
||||||
|
|
||||||
|
comment on column adi_ai_search_record.is_deleted is '0: Normal; 1: Deleted';
|
||||||
|
|
||||||
|
create trigger trigger_ai_search_record
|
||||||
|
before update
|
||||||
|
on adi_ai_search_record
|
||||||
|
for each row
|
||||||
|
execute procedure update_modified_column();
|
Loading…
Reference in New Issue