增加RAG大模型知识库

This commit is contained in:
moyangzhan 2024-02-15 14:29:57 +08:00
parent a2c1ed5390
commit 029d6b9d50
76 changed files with 2596 additions and 424 deletions

View File

@ -1,4 +1,4 @@
# Getting Started ## Getting Started
> 声明:此项目只发布于 Github基于 MIT 协议,免费且作为开源学习使用。并且不会有任何形式的卖号等行为,谨防受骗。 > 声明:此项目只发布于 Github基于 MIT 协议,免费且作为开源学习使用。并且不会有任何形式的卖号等行为,谨防受骗。
@ -20,19 +20,19 @@
#### 初始化 #### 初始化
* 初始化数据库 初始化数据库
* 创建数据库aideepin * 创建数据库aideepin
* 执行docs/create.sql * 执行docs/create.sql
* 填充openai的secret_key * 填充openai的secret\_key
``` ```plaintext
update adi_sys_config set value = 'my_chatgpt_secret_key' where name = 'secret_key' update adi_sys_config set value = 'my_chatgpt_secret_key' where name = 'secret_key'
``` ```
* 修改配置文件 * 修改配置文件
* mysql: application-[dev|prod].xml中的spring.datasource * postgresql: application-[dev|prod].xml中的spring.datasource
* redis: application-[dev|prod].xml中的spring.data.redis * redis: application-[dev|prod].xml中的spring.data.redis
* mail: application.xml中的spring.mail * mail: application.xml中的spring.mail
@ -40,7 +40,7 @@ update adi_sys_config set value = 'my_chatgpt_secret_key' where name = 'secret_k
* 进入项目 * 进入项目
``` ```plaintext
cd aideepin cd aideepin
``` ```
@ -54,14 +54,14 @@ mvn clean package -Dmaven.test.skip=true
a. jar包启动 a. jar包启动
``` ```plaintext
cd adi-chat/target cd adi-chat/target
nohup java -jar -Xms768m -Xmx1024m -XX:+HeapDumpOnOutOfMemoryError adi-chat-0.0.1-SNAPSHOT.jar --spring.profiles.active=[dev|prod] dev/null 2>&1 & nohup java -jar -Xms768m -Xmx1024m -XX:+HeapDumpOnOutOfMemoryError adi-chat-0.0.1-SNAPSHOT.jar --spring.profiles.active=[dev|prod] dev/null 2>&1 &
``` ```
b. docker启动 b. docker启动
``` ```plaintext
cd adi-chat cd adi-chat
docker build . -t aideepin:0.0.1 docker build . -t aideepin:0.0.1
docker run -d \ docker run -d \

View File

@ -11,18 +11,7 @@
<artifactId>adi-admin</artifactId> <artifactId>adi-admin</artifactId>
<properties>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies> <dependencies>
<dependency>
<groupId>com.moyz</groupId>
<artifactId>adi-bootstrap</artifactId>
<version>0.0.1-SNAPSHOT</version>
</dependency>
<dependency> <dependency>
<groupId>com.moyz</groupId> <groupId>com.moyz</groupId>
<artifactId>adi-common</artifactId> <artifactId>adi-common</artifactId>

View File

@ -0,0 +1,25 @@
package com.moyz.adi.admin.controller;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.moyz.adi.common.entity.SysConfig;
import com.moyz.adi.common.service.KnowledgeBaseService;
import com.moyz.adi.common.service.SysConfigService;
import jakarta.annotation.Resource;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotNull;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping("/admin/sys-config")
@Validated
public class SystemConfigController {
@Resource
private SysConfigService sysConfigService;
public Page<SysConfig> list(String keyword, @NotNull @Min(1) Integer currentPage, @NotNull @Min(10) Integer pageSize) {
return sysConfigService.search(keyword, currentPage, pageSize);
}
}

View File

@ -11,6 +11,18 @@
<artifactId>adi-bootstrap</artifactId> <artifactId>adi-bootstrap</artifactId>
<dependencies>
<dependency>
<groupId>com.moyz</groupId>
<artifactId>adi-chat</artifactId>
<version>0.0.1-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>com.moyz</groupId>
<artifactId>adi-admin</artifactId>
<version>0.0.1-SNAPSHOT</version>
</dependency>
</dependencies>
<build> <build>
<resources> <resources>
<resource> <resource>

View File

@ -1,8 +1,8 @@
spring: spring:
datasource: datasource:
driver-class-name: com.mysql.cj.jdbc.Driver driver-class-name: org.postgresql.Driver
url: jdbc:mysql://localhost:3306/aideepin?useUnicode=true&characterEncoding=utf8&serverTimezone=GMT%2B8&tinyInt1isBit=false&allowMultiQueries=true url: jdbc:postgresql://172.17.18.164:5432/aideepin2?useUnicode=true&characterEncoding=utf8&serverTimezone=GMT%2B8&tinyInt1isBit=false&allowMultiQueries=true
username: root username: postgres
password: 123456 password: 123456
data: data:
redis: redis:

View File

@ -1,9 +1,9 @@
spring: spring:
datasource: datasource:
driver-class-name: com.mysql.cj.jdbc.Driver driver-class-name: org.postgresql.Driver
url: jdbc:mysql://localhost:3306/aideepin?useUnicode=true&characterEncoding=utf8&serverTimezone=GMT%2B8&tinyInt1isBit=false&allowMultiQueries=true url: jdbc:postgresql://localhost:3306/aideepin?useUnicode=true&characterEncoding=utf8&serverTimezone=GMT%2B8&tinyInt1isBit=false&allowMultiQueries=true
username: your-mysql-account username: your-db-account
password: your-mysql-password password: your-db-password
data: data:
redis: redis:
host: localhost host: localhost

View File

@ -12,11 +12,6 @@
<artifactId>adi-chat</artifactId> <artifactId>adi-chat</artifactId>
<dependencies> <dependencies>
<dependency>
<groupId>com.moyz</groupId>
<artifactId>adi-bootstrap</artifactId>
<version>0.0.1-SNAPSHOT</version>
</dependency>
<dependency> <dependency>
<groupId>com.moyz</groupId> <groupId>com.moyz</groupId>
<artifactId>adi-common</artifactId> <artifactId>adi-common</artifactId>

View File

@ -38,7 +38,7 @@ public class FileController {
@PostMapping(path = "/file/upload", headers = "content-type=multipart/form-data", produces = MediaType.APPLICATION_JSON_VALUE) @PostMapping(path = "/file/upload", headers = "content-type=multipart/form-data", produces = MediaType.APPLICATION_JSON_VALUE)
public Map<String, String> upload(@RequestPart(value = "file") MultipartFile file) { public Map<String, String> upload(@RequestPart(value = "file") MultipartFile file) {
Map<String, String> result = new HashMap<>(); Map<String, String> result = new HashMap<>();
result.put("uuid", fileService.writeToLocal(file)); result.put("uuid", fileService.writeToLocal(file).getUuid());
return result; return result;
} }

View File

@ -0,0 +1,62 @@
package com.moyz.adi.chat.controller;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.moyz.adi.common.dto.KbEditReq;
import com.moyz.adi.common.entity.AdiFile;
import com.moyz.adi.common.entity.KnowledgeBase;
import com.moyz.adi.common.service.KnowledgeBaseService;
import jakarta.annotation.Resource;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotNull;
import org.springframework.http.MediaType;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
@RestController
@RequestMapping("/knowledge-base")
@Validated
public class KnowledgeBaseController {
@Resource
private KnowledgeBaseService knowledgeBaseService;
@PostMapping("/saveOrUpdate")
public KnowledgeBase saveOrUpdate(@RequestBody KbEditReq kbEditReq) {
return knowledgeBaseService.saveOrUpdate(kbEditReq);
}
@PostMapping(path = "/uploadDocs/{uuid}", headers = "content-type=multipart/form-data", produces = MediaType.APPLICATION_JSON_VALUE)
public boolean uploadDocs(@PathVariable String uuid,@RequestParam(value = "embedding", defaultValue = "true") Boolean embedding, @RequestParam("files") MultipartFile[] docs) {
knowledgeBaseService.uploadDocs(uuid, embedding, docs);
return true;
}
@PostMapping(path = "/upload/{uuid}", headers = "content-type=multipart/form-data", produces = MediaType.APPLICATION_JSON_VALUE)
public AdiFile upload(@PathVariable String uuid, @RequestParam(value = "embedding", defaultValue = "true") Boolean embedding, @RequestParam("file") MultipartFile doc) {
return knowledgeBaseService.uploadDoc(uuid, embedding, doc);
}
@GetMapping("/search")
public Page<KnowledgeBase> list(String keyword, Boolean includeOthersPublic, @NotNull @Min(1) Integer currentPage, @NotNull @Min(10) Integer pageSize) {
return knowledgeBaseService.search(keyword, includeOthersPublic, currentPage, pageSize);
}
@GetMapping("/info/{uuid}")
public KnowledgeBase info(@PathVariable String uuid) {
return knowledgeBaseService.lambdaQuery()
.eq(KnowledgeBase::getUuid, uuid)
.eq(KnowledgeBase::getIsDeleted, false)
.one();
}
@PostMapping("/del/{uuid}")
public boolean softDelete(@PathVariable String uuid) {
return knowledgeBaseService.softDelete(uuid);
}
@PostMapping("/embedding/{uuid}")
public boolean embedding(@PathVariable String uuid, @RequestParam(defaultValue = "false") Boolean forceAll) {
return knowledgeBaseService.embedding(uuid, forceAll);
}
}

View File

@ -0,0 +1,25 @@
package com.moyz.adi.chat.controller;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.moyz.adi.common.dto.KbItemEmbeddingDto;
import com.moyz.adi.common.service.KnowledgeBaseEmbeddingService;
import jakarta.annotation.Resource;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping("/knowledge-base-embedding")
@Validated
public class KnowledgeBaseEmbeddingController {
@Resource
private KnowledgeBaseEmbeddingService knowledgeBaseEmbeddingService;
@GetMapping("/list/{kbItemUuid}")
public Page<KbItemEmbeddingDto> list(@PathVariable String kbItemUuid, int currentPage, int pageSize) {
return knowledgeBaseEmbeddingService.listByItemUuid(kbItemUuid, currentPage, pageSize);
}
}

View File

@ -0,0 +1,56 @@
package com.moyz.adi.chat.controller;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.moyz.adi.common.dto.KbItemEditReq;
import com.moyz.adi.common.dto.KbItemEmbeddingBatchReq;
import com.moyz.adi.common.entity.KnowledgeBaseItem;
import com.moyz.adi.common.service.KnowledgeBaseItemService;
import jakarta.annotation.Resource;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotNull;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import java.util.List;
@RestController
@RequestMapping("/knowledge-base-item")
@Validated
public class KnowledgeBaseItemController {
@Resource
private KnowledgeBaseItemService knowledgeBaseItemService;
@PostMapping("/saveOrUpdate")
public KnowledgeBaseItem saveOrUpdate(@RequestBody KbItemEditReq itemEditReq) {
return knowledgeBaseItemService.saveOrUpdate(itemEditReq);
}
@GetMapping("/search")
public Page<KnowledgeBaseItem> search(String kbUuid, String keyword, @NotNull @Min(1) Integer currentPage, @NotNull @Min(10) Integer pageSize) {
return knowledgeBaseItemService.search(kbUuid, keyword, currentPage, pageSize);
}
@GetMapping("/info/{uuid}")
public KnowledgeBaseItem info(@PathVariable String uuid) {
return knowledgeBaseItemService.lambdaQuery()
.eq(KnowledgeBaseItem::getUuid, uuid)
.eq(KnowledgeBaseItem::getIsDeleted, false)
.one();
}
@PostMapping("/embedding/{uuid}")
public boolean embedding(@PathVariable String uuid) {
return knowledgeBaseItemService.checkAndEmbedding(uuid);
}
@PostMapping("/embedding-list")
public boolean embeddingBatch(@RequestBody KbItemEmbeddingBatchReq req) {
return knowledgeBaseItemService.checkAndEmbedding(req.getUuids());
}
@PostMapping("/del/{uuid}")
public boolean softDelete(@PathVariable String uuid) {
return knowledgeBaseItemService.softDelete(uuid);
}
}

View File

@ -0,0 +1,40 @@
package com.moyz.adi.chat.controller;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.moyz.adi.common.dto.QAReq;
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
import com.moyz.adi.common.service.KnowledgeBaseQaRecordService;
import com.moyz.adi.common.service.KnowledgeBaseService;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotNull;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
@Tag(name = "知识库问答controller")
@RequestMapping("/knowledge-base/qa/")
@RestController
public class KnowledgeBaseQAController {
@Resource
private KnowledgeBaseService knowledgeBaseService;
@Resource
private KnowledgeBaseQaRecordService knowledgeBaseQaRecordService;
@PostMapping("/ask/{kbUuid}")
public KnowledgeBaseQaRecord ask(@PathVariable String kbUuid, @RequestBody @Validated QAReq req) {
return knowledgeBaseService.answerAndRecord(kbUuid, req.getQuestion());
}
@GetMapping("/record/search")
public Page<KnowledgeBaseQaRecord> list(String kbUuid, String keyword, @NotNull @Min(1) Integer currentPage, @NotNull @Min(10) Integer pageSize) {
return knowledgeBaseQaRecordService.search(kbUuid, keyword, currentPage, pageSize);
}
@PostMapping("/record/del/{uuid}")
public boolean recordDel(@PathVariable String uuid) {
return knowledgeBaseQaRecordService.softDelele(uuid);
}
}

View File

@ -32,7 +32,7 @@ public class UserController {
@Operation(summary = "用户信息") @Operation(summary = "用户信息")
@GetMapping("/{uuid}") @GetMapping("/{uuid}")
public void login(@Validated @PathVariable String uuid) { public void info(@Validated @PathVariable String uuid) {
log.info(uuid); log.info(uuid);
} }

View File

@ -10,7 +10,7 @@ import java.util.Collections;
public class CodeGenerator { public class CodeGenerator {
public static void main(String[] args) { public static void main(String[] args) {
FastAutoGenerator.create("jdbc:mysql://localhost:3306/aideepin?useUnicode=true&characterEncoding=utf8&serverTimezone=GMT%2B8&tinyInt1isBit=false&allowMultiQueries=true", "root", "123456") FastAutoGenerator.create("jdbc:postgres://172.17.30.40:5432/aideepin?useUnicode=true&characterEncoding=utf8&serverTimezone=GMT%2B8&tinyInt1isBit=false&allowMultiQueries=true", "postgres", "postgres")
.globalConfig(builder -> { .globalConfig(builder -> {
builder.author("moyz") // 设置作者 builder.author("moyz") // 设置作者
.enableSwagger() // 开启 swagger 模式 .enableSwagger() // 开启 swagger 模式
@ -35,7 +35,7 @@ public class CodeGenerator {
.pathInfo(Collections.singletonMap(OutputFile.xml, "D://mybatisplus-generatorcode")); // 设置mapperXml生成路径 .pathInfo(Collections.singletonMap(OutputFile.xml, "D://mybatisplus-generatorcode")); // 设置mapperXml生成路径
}) })
.strategyConfig(builder -> { .strategyConfig(builder -> {
builder.addInclude("adi_user,adi_conversation,adi_conversation_message") // 设置需要生成的表名 builder.addInclude("adi_knowledge_base_qa_record") // 设置需要生成的表名
.addTablePrefix("adi_"); .addTablePrefix("adi_");
builder.mapperBuilder().enableBaseResultMap().enableMapperAnnotation().build(); builder.mapperBuilder().enableBaseResultMap().enableMapperAnnotation().build();
}) })

View File

@ -1,18 +1,18 @@
package com.moyz.adi.common.config; package com.moyz.adi.common.config;
import com.baomidou.mybatisplus.annotation.DbType; import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.core.config.GlobalConfig; import com.baomidou.mybatisplus.core.MybatisConfiguration;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor; import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.BlockAttackInnerInterceptor; import com.baomidou.mybatisplus.extension.plugins.inner.BlockAttackInnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.OptimisticLockerInnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor; import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean; import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean;
import com.moyz.adi.common.util.LocalDateTimeUtil;
import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.moyz.adi.common.util.LocalDateTimeUtil;
import com.pgvector.PGvector;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.session.SqlSessionFactory; import org.apache.ibatis.session.SqlSessionFactory;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
@ -34,7 +34,7 @@ public class BeanConfig {
@Bean @Bean
public RestTemplate restTemplate() { public RestTemplate restTemplate() {
log.info("Configuration==create restTemplate"); log.info("Configuration:create restTemplate");
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
// 设置建立连接超时时间 毫秒 // 设置建立连接超时时间 毫秒
requestFactory.setConnectTimeout(60000); requestFactory.setConnectTimeout(60000);
@ -50,9 +50,9 @@ public class BeanConfig {
@Bean @Bean
@Primary @Primary
public ObjectMapper objectMapper(Jackson2ObjectMapperBuilder builder) { public ObjectMapper objectMapper() {
log.info("Configuration==create objectMapper"); log.info("Configuration:create objectMapper");
ObjectMapper objectMapper = builder.createXmlMapper(false).build(); ObjectMapper objectMapper = new Jackson2ObjectMapperBuilder().createXmlMapper(false).build();
objectMapper.registerModules(LocalDateTimeUtil.getSimpleModule(), new JavaTimeModule(), new Jdk8Module()); objectMapper.registerModules(LocalDateTimeUtil.getSimpleModule(), new JavaTimeModule(), new Jdk8Module());
//设置null值不参与序列化(字段不被显示) //设置null值不参与序列化(字段不被显示)
objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
@ -88,12 +88,18 @@ public class BeanConfig {
bean.setDataSource(dataSource); bean.setDataSource(dataSource);
MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor(); MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
// 分页插件 // 分页插件
interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL)); interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.POSTGRE_SQL));
// 防止全表更新 // 防止全表更新
interceptor.addInnerInterceptor(new BlockAttackInnerInterceptor()); interceptor.addInnerInterceptor(new BlockAttackInnerInterceptor());
bean.setPlugins(interceptor); bean.setPlugins(interceptor);
bean.setMapperLocations( bean.setMapperLocations(
new PathMatchingResourcePatternResolver().getResources("classpath*:/mapper/*.xml")); new PathMatchingResourcePatternResolver().getResources("classpath*:/mapper/*.xml"));
MybatisConfiguration configuration = bean.getConfiguration();
if(null == configuration){
configuration = new MybatisConfiguration();
bean.setConfiguration(configuration);
}
bean.getConfiguration().getTypeHandlerRegistry().register(PGvector.class, PostgresVectorTypeHandler.class);
return bean.getObject(); return bean.getObject();
} }

View File

@ -0,0 +1,11 @@
package com.moyz.adi.common.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;
@Configuration
@ConfigurationProperties("adi.dev-mock")
@Data
public class DevMockProperty {
}

View File

@ -0,0 +1,45 @@
package com.moyz.adi.common.config;
import com.pgvector.PGvector;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
public class PostgresVectorTypeHandler extends BaseTypeHandler<PGvector> {
@Override
public void setNonNullParameter(PreparedStatement ps, int i, PGvector parameter, JdbcType jdbcType)
throws SQLException {
ps.setObject(i, parameter);
// ps.setArray(i, ps.getConnection().createArrayOf("float", parameter));
}
@Override
public PGvector getNullableResult(ResultSet rs, String columnName) throws SQLException {
return toFloatArray(rs.getArray(columnName));
}
@Override
public PGvector getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
return toFloatArray(rs.getArray(columnIndex));
}
@Override
public PGvector getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
return toFloatArray(cs.getArray(columnIndex));
}
private PGvector toFloatArray(java.sql.Array sqlArray) throws SQLException {
PGvector pGvector = new PGvector(new float[0]);
if (sqlArray == null) {
return pGvector;
}
pGvector.setValue(sqlArray.toString());
return pGvector;
}
}

View File

@ -54,7 +54,6 @@ public class AdiConstant {
public static final List<String> OPENAI_CREATE_IMAGE_SIZES = List.of("256x256", "512x512", "1024x1024"); public static final List<String> OPENAI_CREATE_IMAGE_SIZES = List.of("256x256", "512x512", "1024x1024");
public static class GenerateImage { public static class GenerateImage {
public static final int INTERACTING_METHOD_GENERATE_IMAGE = 1; public static final int INTERACTING_METHOD_GENERATE_IMAGE = 1;
public static final int INTERACTING_METHOD_EDIT_IMAGE = 2; public static final int INTERACTING_METHOD_EDIT_IMAGE = 2;
@ -78,4 +77,6 @@ public class AdiConstant {
public static final String QUOTA_BY_IMAGE_MONTHLY = "quota_by_image_monthly"; public static final String QUOTA_BY_IMAGE_MONTHLY = "quota_by_image_monthly";
} }
public static final String[] POI_DOC_TYPES = {"doc", "docx", "ppt", "pptx", "xls", "xlsx"};
} }

View File

@ -78,4 +78,17 @@ public class RedisKeyConstant {
* : 用户id用于校验后续流程中的重置密码使用 * : 用户id用于校验后续流程中的重置密码使用
*/ */
public static final String FIND_MY_PASSWORD = "user:find:password:{0}"; public static final String FIND_MY_PASSWORD = "user:find:password:{0}";
/**
* qa提问次数每天
* 参数用户id:日期yyyyMMdd
* 提问数量
*/
public static final String AQ_ASK_TIMES = "qa:ask:limit:{0}:{1}";
/**
* 知识库知识点生成数量
* : 用户id
*/
public static final String qa_item_create_limit = "aq:item:create:{0}";
} }

View File

@ -21,11 +21,11 @@ public class ConvMsgResp {
private Long parentMessageId; private Long parentMessageId;
@Schema(title = "对话的消息") @Schema(title = "对话的消息")
@TableField("content") @TableField("remark")
private String content; private String remark;
@Schema(title = "产生该消息的角色1: 用户,2:系统,3:助手") @Schema(title = "产生该消息的角色1: 用户,2:系统,3:助手")
private String messageRole; private Integer messageRole;
@Schema(title = "消耗的token数量") @Schema(title = "消耗的token数量")
private Integer tokens; private Integer tokens;

View File

@ -0,0 +1,21 @@
package com.moyz.adi.common.dto;
import jakarta.validation.constraints.NotBlank;
import lombok.Data;
import org.springframework.validation.annotation.Validated;
@Data
@Validated
public class KbEditReq {
private Long id;
private String uuid;
@NotBlank
private String title;
private String remark;
private Boolean isPublic;
}

View File

@ -0,0 +1,28 @@
package com.moyz.adi.common.dto;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotBlank;
import lombok.Data;
import org.springframework.validation.annotation.Validated;
@Data
@Validated
public class KbItemEditReq {
private Long id;
@Min(1)
private Long kbId;
private String kbUuid;
private String uuid;
@NotBlank
private String title;
private String brief;
@NotBlank
private String remark;
}

View File

@ -0,0 +1,8 @@
package com.moyz.adi.common.dto;
import lombok.Data;
@Data
public class KbItemEmbeddingBatchReq {
private String[] uuids;
}

View File

@ -0,0 +1,12 @@
package com.moyz.adi.common.dto;
import lombok.Data;
@Data
public class KbItemEmbeddingDto {
private String embeddingId;
private float[] embedding;
private String text;
}

View File

@ -0,0 +1,13 @@
package com.moyz.adi.common.dto;
import jakarta.validation.constraints.NotBlank;
import lombok.Data;
import org.springframework.validation.annotation.Validated;
@Validated
@Data
public class QAReq {
@NotBlank
private String question;
}

View File

@ -37,7 +37,4 @@ public class AdiFile extends BaseEntity {
@TableField(value = "ref_count") @TableField(value = "ref_count")
private Integer refCount; private Integer refCount;
@Schema(title = "是否删除0未删除1已删除")
@TableField(value = "is_delete")
private Boolean isDelete;
} }

View File

@ -45,7 +45,4 @@ public class AiImage extends BaseEntity {
@TableField("process_status") @TableField("process_status")
private Integer processStatus; private Integer processStatus;
@TableField("is_delete")
private Boolean isDelete;
} }

View File

@ -3,16 +3,13 @@ package com.moyz.adi.common.entity;
import com.baomidou.mybatisplus.annotation.IdType; import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
import lombok.Getter; import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Setter; import lombok.Data;
import lombok.ToString;
import java.io.Serializable; import java.io.Serializable;
import java.time.LocalDateTime; import java.time.LocalDateTime;
@Getter @Data
@Setter
@ToString
public class BaseEntity implements Serializable { public class BaseEntity implements Serializable {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
@ -26,4 +23,7 @@ public class BaseEntity implements Serializable {
@TableField(value = "update_time") @TableField(value = "update_time")
private LocalDateTime updateTime; private LocalDateTime updateTime;
@Schema(title = "是否删除0未删除1已删除")
@TableField(value = "is_deleted")
private Boolean isDeleted;
} }

View File

@ -45,7 +45,4 @@ public class Conversation extends BaseEntity {
@Schema(title = "set the system message to ai, ig: you are a lawyer") @Schema(title = "set the system message to ai, ig: you are a lawyer")
@TableField("ai_system_message") @TableField("ai_system_message")
private String aiSystemMessage; private String aiSystemMessage;
@TableField(value = "is_delete")
private Boolean isDelete;
} }

View File

@ -39,12 +39,12 @@ public class ConversationMessage extends BaseEntity {
private Long userId; private Long userId;
@Schema(title = "对话的消息") @Schema(title = "对话的消息")
@TableField("content") @TableField("remark")
private String content; private String remark;
@Schema(title = "产生该消息的角色1: 用户,2:系统,3:助手") @Schema(title = "产生该消息的角色1: 用户,2:系统,3:助手")
@TableField("message_role") @TableField("message_role")
private String messageRole; private Integer messageRole;
@Schema(title = "消耗的token数量") @Schema(title = "消耗的token数量")
@TableField("tokens") @TableField("tokens")
@ -57,7 +57,4 @@ public class ConversationMessage extends BaseEntity {
@Schema(name = "上下文理解中携带的消息对数量(提示词及回复)") @Schema(name = "上下文理解中携带的消息对数量(提示词及回复)")
@TableField("understand_context_msg_pair_num") @TableField("understand_context_msg_pair_num")
private Integer understandContextMsgPairNum; private Integer understandContextMsgPairNum;
@TableField(value = "is_delete")
private Boolean isDelete;
} }

View File

@ -0,0 +1,36 @@
package com.moyz.adi.common.entity;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
@Data
@TableName("adi_knowledge_base")
@Schema(title = "知识库实体", description = "知识库表")
public class KnowledgeBase extends BaseEntity {
@Schema(title = "uuid")
@TableField("uuid")
private String uuid;
@Schema(title = "名称")
@TableField("title")
private String title;
@Schema(title = "描述")
@TableField("remark")
private String remark;
@Schema(title = "是否公开")
@TableField("is_public")
private Boolean isPublic;
@Schema(title = "所属人id")
@TableField("owner_id")
private Long ownerId;
@Schema(title = "所属人名称")
@TableField("owner_name")
private String ownerName;
}

View File

@ -0,0 +1,25 @@
package com.moyz.adi.common.entity;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import com.pgvector.PGvector;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
@Data
@TableName("adi_knowledge_base_embedding")
@Schema(title = "知识库-嵌入实体", description = "知识库嵌入表")
public class KnowledgeBaseEmbedding extends BaseEntity {
@Schema(title = "embedding uuid")
@TableField("embedding")
private String embeddingId;
@Schema(title = "embedding")
@TableField("embedding")
private PGvector embedding;
@Schema(title = "对应的文档")
@TableField("text")
private String text;
}

View File

@ -0,0 +1,44 @@
package com.moyz.adi.common.entity;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
@Data
@TableName("adi_knowledge_base_item")
@Schema(title = "知识库条目实体", description = "知识库条目表")
public class KnowledgeBaseItem extends BaseEntity {
@Schema(title = "知识库id")
@TableField("kb_id")
private Long kbId;
@Schema(title = "知识库uuid")
@TableField("kb_uuid")
private String kbUuid;
@Schema(title = "名称")
@TableField("source_file_id")
private Long sourceFileId;
@Schema(title = "uuid")
@TableField("uuid")
private String uuid;
@Schema(title = "标题")
@TableField("title")
private String title;
@Schema(title = "内容摘要")
@TableField("brief")
private String brief;
@Schema(title = "内容")
@TableField("remark")
private String remark;
@Schema(title = "是否已向量化")
@TableField("is_embedded")
private Boolean isEmbedded;
}

View File

@ -0,0 +1,41 @@
package com.moyz.adi.common.entity;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import org.apache.ibatis.type.JdbcType;
@Data
@TableName("adi_knowledge_base_qa_record")
@Schema(title = "知识库问答记录实体", description = "知识库问答记录表")
public class KnowledgeBaseQaRecord extends BaseEntity {
@Schema(title = "uuid")
@TableField(value = "uuid", jdbcType = JdbcType.VARCHAR)
private String uuid;
@Schema(title = "知识库id")
@TableField("kb_id")
private Long kbId;
@Schema(title = "知识库uuid")
@TableField("kb_uuid")
private String kbUuid;
@Schema(title = "来源文档id,以逗号隔开")
@TableField("source_file_ids")
private String sourceFileIds;
@Schema(title = "问题")
@TableField("question")
private String question;
@Schema(title = "答案")
@TableField("answer")
private String answer;
@Schema(title = "提问用户id")
@TableField("user_id")
private Long userId;
}

View File

@ -22,7 +22,4 @@ public class Prompt extends BaseEntity {
@TableField(value = "prompt") @TableField(value = "prompt")
private String prompt; private String prompt;
@Schema(title = "是否删除0未删除1已删除")
@TableField(value = "is_delete")
private Boolean isDelete;
} }

View File

@ -17,7 +17,4 @@ public class SysConfig extends BaseEntity {
@Schema(title = "配置项的值") @Schema(title = "配置项的值")
private String value; private String value;
@Schema(title = "是否删除0未删除1已删除")
@TableField(value = "is_delete")
private Boolean isDelete;
} }

View File

@ -62,6 +62,7 @@ public class User extends BaseEntity {
@TableField("active_time") @TableField("active_time")
private LocalDateTime activeTime; private LocalDateTime activeTime;
@TableField("is_delete") @Schema(title = "是否管理员01")
private Boolean isDelete; @TableField(value = "is_admin")
private Boolean isAdmin;
} }

View File

@ -0,0 +1,18 @@
package com.moyz.adi.common.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
public enum ChatMessageRoleEnum implements BaseEnum {
USER(1, "user"),
SYSTEM(2, "system"),
ASSISTANT(3, "assistant");
private final Integer value;
private final String desc;
}

View File

@ -17,6 +17,11 @@ public enum ErrorEnum {
A_REGISTER_USER_EXIST("A0013", "账号已经存在,请使用账号密码登录"), A_REGISTER_USER_EXIST("A0013", "账号已经存在,请使用账号密码登录"),
A_FIND_PASSWORD_CODE_ERROR("A0014", "重置码已过期或不存在"), A_FIND_PASSWORD_CODE_ERROR("A0014", "重置码已过期或不存在"),
A_USER_WAIT_CONFIRM("A0015", "用户未激活"), A_USER_WAIT_CONFIRM("A0015", "用户未激活"),
A_USER_NOT_AUTH("A0016", "用户无权限"),
A_DATA_NOT_FOUND("A0017", "数据不存在"),
A_UPLOAD_FAIL("A0018", "上传失败"),
A_QA_ASK_LIMIT("A0019", "请求次数太多"),
A_QA_ITEM_LIMIT("A0020", "知识点生成已超额度"),
B_UNCAUGHT_ERROR("B0001", "未捕捉异常"), B_UNCAUGHT_ERROR("B0001", "未捕捉异常"),
B_COMMON_ERROR("B0002", "业务出错"), B_COMMON_ERROR("B0002", "业务出错"),
B_GLOBAL_ERROR("B0003", "全局异常"), B_GLOBAL_ERROR("B0003", "全局异常"),

View File

@ -29,7 +29,7 @@ import static org.springframework.http.HttpHeaders.AUTHORIZATION;
public class TokenFilter extends OncePerRequestFilter { public class TokenFilter extends OncePerRequestFilter {
public static final String[] EXCLUDE_API = { public static final String[] EXCLUDE_API = {
"/auth/", "/auth/"
}; };
@Resource @Resource

View File

@ -0,0 +1,156 @@
package com.moyz.adi.common.helper;
import com.moyz.adi.common.util.AdiPgVectorEmbeddingStore;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static java.util.stream.Collectors.joining;
@Slf4j
@Service
public class EmbeddingHelper {
@Value("${spring.datasource.url}")
private String dataBaseUrl;
@Value("${spring.datasource.username}")
private String dataBaseUserName;
@Value("${spring.datasource.password}")
private String dataBasePassword;
@Value("${openai.proxy.enable:false}")
private boolean proxyEnable;
@Value("${openai.proxy.host:0}")
private String proxyHost;
@Value("${openai.proxy.http-port:0}")
private int proxyHttpPort;
private static final PromptTemplate promptTemplate = PromptTemplate.from("尽可能准确地回答下面的问题: {{question}}\n\n根据以下知识库的内容:\n{{information}}");
@Resource
private OpenAiHelper openAiHelper;
private EmbeddingModel embeddingModel;
private EmbeddingStore<TextSegment> embeddingStore;
private ChatLanguageModel chatLanguageModel;
public void init() {
log.info("initEmbeddingModel");
embeddingModel = new AllMiniLmL6V2EmbeddingModel();
embeddingStore = initEmbeddingStore();
chatLanguageModel = initChatLanguageModel();
}
private EmbeddingStore<TextSegment> initEmbeddingStore() {
// 正则表达式匹配
String regex = "jdbc:postgresql://([^:/]+):(\\d+)/(\\w+).+";
Pattern pattern = Pattern.compile(regex);
Matcher matcher = pattern.matcher(dataBaseUrl);
String host = "";
String port = "";
String databaseName = "";
if (matcher.matches()) {
host = matcher.group(1);
port = matcher.group(2);
databaseName = matcher.group(3);
System.out.println("Host: " + host);
System.out.println("Port: " + port);
System.out.println("Database: " + databaseName);
} else {
throw new RuntimeException("parse url error");
}
AdiPgVectorEmbeddingStore embeddingStore = AdiPgVectorEmbeddingStore.builder()
.host(host)
.port(Integer.parseInt(port))
.database(databaseName)
.user(dataBaseUserName)
.password(dataBasePassword)
.dimension(384)
.createTable(true)
.dropTableFirst(false)
.table("adi_knowledge_base_embedding")
.build();
return embeddingStore;
}
private ChatLanguageModel initChatLanguageModel() {
OpenAiChatModel.OpenAiChatModelBuilder builder = OpenAiChatModel.builder().apiKey(openAiHelper.getSecretKey());
if (proxyEnable) {
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort));
builder.proxy(proxy);
}
return builder.build();
}
public EmbeddingStoreIngestor getEmbeddingStoreIngestor() {
DocumentSplitter documentSplitter = DocumentSplitters.recursive(1000, 0, new OpenAiTokenizer(GPT_3_5_TURBO));
EmbeddingStoreIngestor embeddingStoreIngestor = EmbeddingStoreIngestor.builder()
.documentSplitter(documentSplitter)
.embeddingModel(embeddingModel)
.embeddingStore(embeddingStore)
.build();
return embeddingStoreIngestor;
}
public String findAnswer(String kbUuid, String question) {
// Embed the question
Embedding questionEmbedding = embeddingModel.embed(question).content();
// Find relevant embeddings in embedding store by semantic similarity
// You can play with parameters below to find a sweet spot for your specific use case
int maxResults = 3;
double minScore = 0.6;
List<EmbeddingMatch<TextSegment>> relevantEmbeddings = ((AdiPgVectorEmbeddingStore) embeddingStore).findRelevantByKbUuid(kbUuid, questionEmbedding, maxResults, minScore);
// Create a prompt for the model that includes question and relevant embeddings
String information = relevantEmbeddings.stream()
.map(match -> match.embedded().text())
.collect(joining("\n\n"));
if (StringUtils.isBlank(information)) {
return StringUtils.EMPTY;
}
Prompt prompt = promptTemplate.apply(Map.of("question", question, "information", Matcher.quoteReplacement(information)));
AiMessage aiMessage = chatLanguageModel.generate(prompt.toUserMessage()).content();
// See an answer from the model
return aiMessage.text();
}
}

View File

@ -1,30 +1,34 @@
package com.moyz.adi.common.helper; package com.moyz.adi.common.helper;
import com.didalgo.gpt3.ChatFormatDescriptor;
import com.didalgo.gpt3.Encoding;
import com.didalgo.gpt3.GPT3Tokenizer;
import com.didalgo.gpt3.TokenCount;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.moyz.adi.common.cosntant.AdiConstant; import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.entity.AiImage; import com.moyz.adi.common.entity.AiImage;
import com.moyz.adi.common.entity.User; import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.model.AnswerMeta; import com.moyz.adi.common.interfaces.IChatAssistant;
import com.moyz.adi.common.model.ChatMeta;
import com.moyz.adi.common.model.QuestionMeta;
import com.moyz.adi.common.service.FileService; import com.moyz.adi.common.service.FileService;
import com.moyz.adi.common.service.SysConfigService; import com.moyz.adi.common.service.SysConfigService;
import com.moyz.adi.common.util.FileUtil;
import com.moyz.adi.common.util.ImageUtil; import com.moyz.adi.common.util.ImageUtil;
import com.moyz.adi.common.util.JsonUtil; import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.TriConsumer; import com.moyz.adi.common.util.TriConsumer;
import com.moyz.adi.common.vo.AnswerMeta;
import com.moyz.adi.common.vo.ChatMeta;
import com.moyz.adi.common.vo.QuestionMeta;
import com.moyz.adi.common.vo.SseAskParams;
import com.theokanning.openai.OpenAiApi; import com.theokanning.openai.OpenAiApi;
import com.theokanning.openai.completion.chat.ChatCompletionChoice; import com.theokanning.openai.image.CreateImageEditRequest;
import com.theokanning.openai.completion.chat.ChatCompletionRequest; import com.theokanning.openai.image.CreateImageVariationRequest;
import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.image.ImageResult;
import com.theokanning.openai.image.*;
import com.theokanning.openai.service.OpenAiService; import com.theokanning.openai.service.OpenAiService;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.openai.OpenAiImageModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient; import okhttp3.OkHttpClient;
@ -35,19 +39,23 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import retrofit2.Retrofit; import retrofit2.Retrofit;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.Proxy; import java.net.Proxy;
import java.time.Duration; import java.time.Duration;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Collectors;
import static com.moyz.adi.common.cosntant.AdiConstant.OPENAI_CREATE_IMAGE_RESP_FORMATS_URL; import static com.moyz.adi.common.cosntant.AdiConstant.OPENAI_CREATE_IMAGE_RESP_FORMATS_URL;
import static com.moyz.adi.common.cosntant.AdiConstant.OPENAI_CREATE_IMAGE_SIZES; import static com.moyz.adi.common.cosntant.AdiConstant.OPENAI_CREATE_IMAGE_SIZES;
import static com.theokanning.openai.service.OpenAiService.defaultClient; import static com.theokanning.openai.service.OpenAiService.defaultClient;
import static com.theokanning.openai.service.OpenAiService.defaultRetrofit; import static com.theokanning.openai.service.OpenAiService.defaultRetrofit;
import static dev.ai4j.openai4j.image.ImageModel.DALL_E_SIZE_1024_x_1024;
import static dev.ai4j.openai4j.image.ImageModel.DALL_E_SIZE_512_x_512;
import static dev.langchain4j.model.openai.OpenAiModelName.DALL_E_2;
@Slf4j @Slf4j
@Service @Service
@ -62,21 +70,23 @@ public class OpenAiHelper {
@Value("${openai.proxy.http-port:0}") @Value("${openai.proxy.http-port:0}")
private int proxyHttpPort; private int proxyHttpPort;
@Value("${local.images}")
private String localImagesPath;
@Resource @Resource
private FileService fileService; private FileService fileService;
@Resource @Resource
private ObjectMapper objectMapper; private ObjectMapper objectMapper;
public OpenAiService getOpenAiService(User user) { public String getSecretKey() {
String secretKey = SysConfigService.getSecretKey(); String secretKey = SysConfigService.getSecretKey();
String userSecretKey = user.getSecretKey(); User user = ThreadContext.getCurrentUser();
if (StringUtils.isNotBlank(userSecretKey)) { if (null != user && StringUtils.isNotBlank(user.getSecretKey())) {
secretKey = userSecretKey; secretKey = user.getSecretKey();
} }
return secretKey;
}
public OpenAiService getOpenAiService() {
String secretKey = getSecretKey();
if (proxyEnable) { if (proxyEnable) {
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort)); Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort));
OkHttpClient client = defaultClient(secretKey, Duration.of(60, ChronoUnit.SECONDS)) OkHttpClient client = defaultClient(secretKey, Duration.of(60, ChronoUnit.SECONDS))
@ -90,102 +100,121 @@ public class OpenAiHelper {
return new OpenAiService(secretKey, Duration.of(60, ChronoUnit.SECONDS)); return new OpenAiService(secretKey, Duration.of(60, ChronoUnit.SECONDS));
} }
/** public IChatAssistant getChatAssistant(ChatMemory chatMemory) {
* Send http request to openai server <br/> String secretKey = getSecretKey();
* Calculate token OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder builder = OpenAiStreamingChatModel.builder();
* if (proxyEnable) {
* @param user Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort));
* @param regenerateQuestionUuid builder.proxy(proxy);
* @param chatMessageList }
* @param sseEmitter builder.apiKey(secretKey).timeout(Duration.of(60, ChronoUnit.SECONDS));
* @param consumer AiServices<IChatAssistant> serviceBuilder = AiServices.builder(IChatAssistant.class)
*/ .streamingChatLanguageModel(builder.build());
public void sseAsk(User user, String regenerateQuestionUuid, List<ChatMessage> chatMessageList, SseEmitter sseEmitter, TriConsumer<String, QuestionMeta, AnswerMeta> consumer) { if (null != chatMemory) {
final int[] answerTokens = {0}; serviceBuilder.chatMemory(chatMemory);
StringBuilder response = new StringBuilder(); }
OpenAiService service = getOpenAiService(user); return serviceBuilder.build();
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest }
.builder()
.model(AdiConstant.DEFAULT_MODEL) public ImageModel getImageModel(User user, String size) {
.messages(chatMessageList) String secretKey = getSecretKey();
.n(1) if (proxyEnable) {
.logitBias(new HashMap<>()) Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(proxyHost, proxyHttpPort));
return OpenAiImageModel.builder()
.modelName(DALL_E_2)
.apiKey(secretKey)
.user(user.getUuid())
.responseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL)
.size(StringUtils.defaultString(size, DALL_E_SIZE_512_x_512))
.logRequests(true)
.logResponses(true)
.withPersisting(false)
.maxRetries(2)
.proxy(proxy)
.build(); .build();
service.streamChatCompletion(chatCompletionRequest) }
.doOnError(onError -> { return OpenAiImageModel.builder()
log.error("openai error", onError); .modelName(DALL_E_2)
sseEmitter.send(SseEmitter.event().name("error").data(onError.getMessage())); .apiKey(secretKey)
sseEmitter.complete(); .user(user.getUuid())
}).subscribe(completionChunk -> { .responseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL)
answerTokens[0]++; .size(StringUtils.defaultString(size, DALL_E_SIZE_512_x_512))
List<ChatCompletionChoice> choices = completionChunk.getChoices(); .logRequests(true)
String content = choices.get(0).getMessage().getContent(); .logResponses(true)
.withPersisting(false)
.maxRetries(2)
.build();
}
/**
* Send http request to llm server
*/
public void sseAsk(SseAskParams params, TriConsumer<String, QuestionMeta, AnswerMeta> consumer) {
IChatAssistant chatAssistant = getChatAssistant(params.getChatMemory());
TokenStream tokenStream;
if (StringUtils.isNotBlank(params.getSystemMessage())) {
tokenStream = chatAssistant.chat(params.getSystemMessage(), params.getUserMessage());
} else {
tokenStream = chatAssistant.chat(params.getUserMessage());
}
tokenStream.onNext((content) -> {
log.info("get content:{}", content); log.info("get content:{}", content);
if (null == content && response.isEmpty()) { //加空格配合前端的fetchEventSource进行解析见https://github.com/Azure/fetch-event-source/blob/45ac3cfffd30b05b79fbf95c21e67d4ef59aa56a/src/parse.ts#L129-L133
return;
}
if (null == content || AdiConstant.OPENAI_MESSAGE_DONE_FLAG.equals(content)) {
log.info("OpenAI返回数据结束了");
sseEmitter.send(AdiConstant.OPENAI_MESSAGE_DONE_FLAG);
GPT3Tokenizer tokenizer = new GPT3Tokenizer(Encoding.CL100K_BASE);
int questionTokens = 0;
try { try {
questionTokens = TokenCount.fromMessages(chatMessageList, tokenizer, ChatFormatDescriptor.forModel(AdiConstant.DEFAULT_MODEL)); params.getSseEmitter().send(" " + content);
} catch (IllegalArgumentException e) { } catch (IOException e) {
log.error("该模型的token无法统计,model:{}", AdiConstant.DEFAULT_MODEL); log.error("stream onNext error", e);
} }
System.out.println("requestTokens:" + questionTokens); })
System.out.println("返回内容:" + response); .onComplete((response) -> {
log.info("返回数据结束了:{}", response);
String questionUuid = StringUtils.isNotBlank(regenerateQuestionUuid) ? regenerateQuestionUuid : UUID.randomUUID().toString().replace("-", ""); String questionUuid = StringUtils.isNotBlank(params.getRegenerateQuestionUuid()) ? params.getRegenerateQuestionUuid() : UUID.randomUUID().toString().replace("-", "");
QuestionMeta questionMeta = new QuestionMeta(questionTokens, questionUuid); QuestionMeta questionMeta = new QuestionMeta(response.tokenUsage().inputTokenCount(), questionUuid);
AnswerMeta answerMeta = new AnswerMeta(answerTokens[0], UUID.randomUUID().toString().replace("-", "")); AnswerMeta answerMeta = new AnswerMeta(response.tokenUsage().outputTokenCount(), UUID.randomUUID().toString().replace("-", ""));
ChatMeta chatMeta = new ChatMeta(questionMeta, answerMeta); ChatMeta chatMeta = new ChatMeta(questionMeta, answerMeta);
// String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", "");
String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", ""); String meta = JsonUtil.toJson(chatMeta).replaceAll("\r\n", "");
log.info("meta:" + meta); log.info("meta:" + meta);
sseEmitter.send(" [META]" + meta); try {
params.getSseEmitter().send(" [META]" + meta);
} catch (IOException e) {
log.error("stream onComplete error", e);
throw new RuntimeException(e);
}
// close eventSourceEmitter after tokens was calculated // close eventSourceEmitter after tokens was calculated
sseEmitter.complete(); params.getSseEmitter().complete();
consumer.accept(response.toString(), questionMeta, answerMeta); consumer.accept(response.content().text(), questionMeta, answerMeta);
return; })
.onError((error) -> {
log.error("stream error", error);
try {
params.getSseEmitter().send(SseEmitter.event().name("error").data(error.getMessage()));
} catch (IOException e) {
log.error("sse error", e);
} }
//加空格配合前端的fetchEventSource进行解析见https://github.com/Azure/fetch-event-source/blob/45ac3cfffd30b05b79fbf95c21e67d4ef59aa56a/src/parse.ts#L129-L133 params.getSseEmitter().complete();
sseEmitter.send(" " + content); })
response.append(content); .start();
});
System.out.println("返回内容1111" + response);
} }
public List<Image> createImage(User user, AiImage aiImage) { public List<String> createImage(User user, AiImage aiImage) {
if (aiImage.getGenerateNumber() < 1 || aiImage.getGenerateNumber() > 10) { if (aiImage.getGenerateNumber() < 1 || aiImage.getGenerateNumber() > 10) {
throw new BaseException(ErrorEnum.A_IMAGE_NUMBER_ERROR); throw new BaseException(ErrorEnum.A_IMAGE_NUMBER_ERROR);
} }
if (!OPENAI_CREATE_IMAGE_SIZES.contains(aiImage.getGenerateSize())) { if (!OPENAI_CREATE_IMAGE_SIZES.contains(aiImage.getGenerateSize())) {
throw new BaseException(ErrorEnum.A_IMAGE_SIZE_ERROR); throw new BaseException(ErrorEnum.A_IMAGE_SIZE_ERROR);
} }
OpenAiService service = getOpenAiService(user); ImageModel imageModel = getImageModel(user, aiImage.getGenerateSize());
CreateImageRequest createImageRequest = new CreateImageRequest();
createImageRequest.setPrompt(aiImage.getPrompt());
createImageRequest.setN(aiImage.getGenerateNumber());
createImageRequest.setSize(aiImage.getGenerateSize());
createImageRequest.setResponseFormat(OPENAI_CREATE_IMAGE_RESP_FORMATS_URL);
createImageRequest.setUser(user.getUuid());
try { try {
ImageResult imageResult = service.createImage(createImageRequest); Response<List<Image>> response = imageModel.generate(aiImage.getPrompt(), aiImage.getGenerateNumber());
log.info("createImage response:{}", imageResult); log.info("createImage response:{}", response);
return imageResult.getData(); return response.content().stream().map(item -> item.url().toString()).collect(Collectors.toList());
} catch (Exception e) { } catch (Exception e) {
log.error("create image error", e); log.error("create image error", e);
} }
return Collections.emptyList(); return Collections.emptyList();
} }
public List<Image> editImage(User user, AiImage aiImage) { public List<String> editImage(User user, AiImage aiImage) {
File originalFile = new File(fileService.getImagePath(aiImage.getOriginalImage())); File originalFile = new File(fileService.getImagePath(aiImage.getOriginalImage()));
File maskFile = null; File maskFile = null;
if (StringUtils.isNotBlank(aiImage.getMaskImage())) { if (StringUtils.isNotBlank(aiImage.getMaskImage())) {
@ -193,7 +222,7 @@ public class OpenAiHelper {
} }
//如果不是RGBA类型的图片先转成RGBA //如果不是RGBA类型的图片先转成RGBA
File rgbaOriginalImage = ImageUtil.rgbConvertToRgba(originalFile, fileService.getTmpImagesPath(aiImage.getOriginalImage())); File rgbaOriginalImage = ImageUtil.rgbConvertToRgba(originalFile, fileService.getTmpImagesPath(aiImage.getOriginalImage()));
OpenAiService service = getOpenAiService(user); OpenAiService service = getOpenAiService();
CreateImageEditRequest request = new CreateImageEditRequest(); CreateImageEditRequest request = new CreateImageEditRequest();
request.setPrompt(aiImage.getPrompt()); request.setPrompt(aiImage.getPrompt());
request.setN(aiImage.getGenerateNumber()); request.setN(aiImage.getGenerateNumber());
@ -203,16 +232,16 @@ public class OpenAiHelper {
try { try {
ImageResult imageResult = service.createImageEdit(request, rgbaOriginalImage, maskFile); ImageResult imageResult = service.createImageEdit(request, rgbaOriginalImage, maskFile);
log.info("editImage response:{}", imageResult); log.info("editImage response:{}", imageResult);
return imageResult.getData(); return imageResult.getData().stream().map(item -> item.getUrl()).collect(Collectors.toList());
} catch (Exception e) { } catch (Exception e) {
log.error("edit image error", e); log.error("edit image error", e);
} }
return Collections.emptyList(); return Collections.emptyList();
} }
public List<Image> createImageVariation(User user, AiImage aiImage) { public List<String> createImageVariation(User user, AiImage aiImage) {
File imagePath = new File(fileService.getImagePath(aiImage.getOriginalImage())); File imagePath = new File(fileService.getImagePath(aiImage.getOriginalImage()));
OpenAiService service = getOpenAiService(user); OpenAiService service = getOpenAiService();
CreateImageVariationRequest request = new CreateImageVariationRequest(); CreateImageVariationRequest request = new CreateImageVariationRequest();
request.setN(aiImage.getGenerateNumber()); request.setN(aiImage.getGenerateNumber());
request.setSize(aiImage.getGenerateSize()); request.setSize(aiImage.getGenerateSize());
@ -221,7 +250,7 @@ public class OpenAiHelper {
try { try {
ImageResult imageResult = service.createImageVariation(request, imagePath); ImageResult imageResult = service.createImageVariation(request, imagePath);
log.info("createImageVariation response:{}", imageResult); log.info("createImageVariation response:{}", imageResult);
return imageResult.getData(); return imageResult.getData().stream().map(item -> item.getUrl()).collect(Collectors.toList());
} catch (Exception e) { } catch (Exception e) {
log.error("image variation error", e); log.error("image variation error", e);
} }

View File

@ -2,7 +2,7 @@ package com.moyz.adi.common.helper;
import com.moyz.adi.common.entity.User; import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.model.CostStat; import com.moyz.adi.common.vo.CostStat;
import com.moyz.adi.common.service.UserDayCostService; import com.moyz.adi.common.service.UserDayCostService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;

View File

@ -1,6 +1,6 @@
package com.moyz.adi.common.helper; package com.moyz.adi.common.helper;
import com.moyz.adi.common.model.RequestRateLimit; import com.moyz.adi.common.vo.RequestRateLimit;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.data.redis.core.StringRedisTemplate;

View File

@ -0,0 +1,14 @@
package com.moyz.adi.common.interfaces;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.V;
public interface IChatAssistant {
@SystemMessage("{{sm}}")
TokenStream chat(@V("sm") String systemMessage, @UserMessage String prompt);
TokenStream chat(@UserMessage String prompt);
}

View File

@ -0,0 +1,17 @@
package com.moyz.adi.common.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.moyz.adi.common.entity.KnowledgeBaseEmbedding;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@Mapper
public interface KnowledgeBaseEmbeddingMapper extends BaseMapper<KnowledgeBaseEmbedding> {
Page<KnowledgeBaseEmbedding> selectByItemUuid(Page<KnowledgeBaseEmbedding> page, @Param("kbItemUuid") String uuid);
boolean deleteByItemUuid(@Param("kbItemUuid") String uuid);
}

View File

@ -0,0 +1,9 @@
package com.moyz.adi.common.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.moyz.adi.common.entity.KnowledgeBaseItem;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface KnowledgeBaseItemMapper extends BaseMapper<KnowledgeBaseItem> {
}

View File

@ -0,0 +1,28 @@
package com.moyz.adi.common.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.moyz.adi.common.entity.KnowledgeBase;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
@Mapper
public interface KnowledgeBaseMapper extends BaseMapper<KnowledgeBase> {
/**
* 搜索知识库管理员
*
* @param keyword 关键词
* @return
*/
Page<KnowledgeBase> searchByAdmin(Page<KnowledgeBase> page, @Param("keyword") String keyword);
/**
* 搜索知识库用户
*
* @param ownerId 用户id
* @param keyword 关键词
* @return
*/
Page<KnowledgeBase> searchByUser(Page<KnowledgeBase> page, @Param("ownerId") long ownerId, @Param("keyword") String keyword, @Param("includeOthersPublic") Boolean includeOthersPublic);
}

View File

@ -0,0 +1,10 @@
package com.moyz.adi.common.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface KnowledgeBaseQaRecordMapper extends BaseMapper<KnowledgeBaseQaRecord> {
}

View File

@ -17,7 +17,6 @@ import com.moyz.adi.common.mapper.AiImageMapper;
import com.moyz.adi.common.util.LocalCache; import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.util.LocalDateTimeUtil; import com.moyz.adi.common.util.LocalDateTimeUtil;
import com.moyz.adi.common.util.UserUtil; import com.moyz.adi.common.util.UserUtil;
import com.theokanning.openai.image.Image;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@ -174,7 +173,7 @@ public class AiImageService extends ServiceImpl<AiImageMapper, AiImage> {
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId()); String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.IMAGE_RATE_LIMIT_CONFIG); rateLimitHelper.increaseRequestTimes(requestTimesKey, LocalCache.IMAGE_RATE_LIMIT_CONFIG);
List<Image> images = new ArrayList<>(); List<String> images = new ArrayList<>();
if (aiImage.getInteractingMethod() == INTERACTING_METHOD_GENERATE_IMAGE) { if (aiImage.getInteractingMethod() == INTERACTING_METHOD_GENERATE_IMAGE) {
images = openAiHelper.createImage(user, aiImage); images = openAiHelper.createImage(user, aiImage);
} else if (aiImage.getInteractingMethod() == INTERACTING_METHOD_EDIT_IMAGE) { } else if (aiImage.getInteractingMethod() == INTERACTING_METHOD_EDIT_IMAGE) {
@ -183,8 +182,8 @@ public class AiImageService extends ServiceImpl<AiImageMapper, AiImage> {
images = openAiHelper.createImageVariation(user, aiImage); images = openAiHelper.createImageVariation(user, aiImage);
} }
List<String> imageUuids = new ArrayList(); List<String> imageUuids = new ArrayList();
images.forEach(image -> { images.forEach(imageUrl -> {
String imageUuid = fileService.saveToLocal(user, image.getUrl()); String imageUuid = fileService.saveToLocal(user, imageUrl);
imageUuids.add(imageUuid); imageUuids.add(imageUuid);
}); });
String imageUuidsJoin = imageUuids.stream().collect(Collectors.joining(",")); String imageUuidsJoin = imageUuids.stream().collect(Collectors.joining(","));
@ -192,7 +191,7 @@ public class AiImageService extends ServiceImpl<AiImageMapper, AiImage> {
_this.lambdaUpdate().eq(AiImage::getId, aiImage.getId()).set(AiImage::getProcessStatus, STATUS_FAIL).update(); _this.lambdaUpdate().eq(AiImage::getId, aiImage.getId()).set(AiImage::getProcessStatus, STATUS_FAIL).update();
return; return;
} }
String respImagesPath = images.stream().map(Image::getUrl).collect(Collectors.joining(",")); String respImagesPath = images.stream().collect(Collectors.joining(","));
updateAiImageStatus(aiImage.getId(), respImagesPath, imageUuidsJoin, STATUS_SUCCESS); updateAiImageStatus(aiImage.getId(), respImagesPath, imageUuidsJoin, STATUS_SUCCESS);
//Update the cost of current user //Update the cost of current user
@ -231,7 +230,7 @@ public class AiImageService extends ServiceImpl<AiImageMapper, AiImage> {
public AiImagesListResp listAll(@RequestParam Long maxId, @RequestParam int pageSize) { public AiImagesListResp listAll(@RequestParam Long maxId, @RequestParam int pageSize) {
List<AiImage> list = this.lambdaQuery() List<AiImage> list = this.lambdaQuery()
.eq(AiImage::getUserId, ThreadContext.getCurrentUserId()) .eq(AiImage::getUserId, ThreadContext.getCurrentUserId())
.eq(AiImage::getIsDelete, false) .eq(AiImage::getIsDeleted, false)
.lt(AiImage::getId, maxId) .lt(AiImage::getId, maxId)
.orderByDesc(AiImage::getId) .orderByDesc(AiImage::getId)
.last("limit " + pageSize) .last("limit " + pageSize)

View File

@ -10,19 +10,25 @@ import com.moyz.adi.common.entity.Conversation;
import com.moyz.adi.common.entity.ConversationMessage; import com.moyz.adi.common.entity.ConversationMessage;
import com.moyz.adi.common.entity.User; import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.entity.UserDayCost; import com.moyz.adi.common.entity.UserDayCost;
import com.moyz.adi.common.enums.ChatMessageRoleEnum;
import com.moyz.adi.common.enums.ErrorEnum; import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.exception.BaseException; import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.helper.OpenAiHelper; import com.moyz.adi.common.helper.OpenAiHelper;
import com.moyz.adi.common.helper.QuotaHelper; import com.moyz.adi.common.helper.QuotaHelper;
import com.moyz.adi.common.helper.RateLimitHelper; import com.moyz.adi.common.helper.RateLimitHelper;
import com.moyz.adi.common.mapper.ConversationMessageMapper; import com.moyz.adi.common.mapper.ConversationMessageMapper;
import com.moyz.adi.common.model.AnswerMeta;
import com.moyz.adi.common.model.QuestionMeta;
import com.moyz.adi.common.util.LocalCache; import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.util.LocalDateTimeUtil; import com.moyz.adi.common.util.LocalDateTimeUtil;
import com.moyz.adi.common.util.UserUtil; import com.moyz.adi.common.util.UserUtil;
import com.theokanning.openai.completion.chat.ChatMessage; import com.moyz.adi.common.vo.AnswerMeta;
import com.moyz.adi.common.vo.QuestionMeta;
import com.moyz.adi.common.vo.SseAskParams;
import com.theokanning.openai.completion.chat.ChatMessageRole; import com.theokanning.openai.completion.chat.ChatMessageRole;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.TokenWindowChatMemory;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@ -35,12 +41,12 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException; import java.io.IOException;
import java.text.MessageFormat; import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static com.moyz.adi.common.enums.ErrorEnum.B_MESSAGE_NOT_FOUND; import static com.moyz.adi.common.enums.ErrorEnum.B_MESSAGE_NOT_FOUND;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
@Slf4j @Slf4j
@Service @Service
@ -90,7 +96,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
//check 2: the conversation has been deleted //check 2: the conversation has been deleted
Conversation delConv = conversationService.lambdaQuery() Conversation delConv = conversationService.lambdaQuery()
.eq(Conversation::getUuid, askReq.getConversationUuid()) .eq(Conversation::getUuid, askReq.getConversationUuid())
.eq(Conversation::getIsDelete, true) .eq(Conversation::getIsDeleted, true)
.one(); .one();
if (null != delConv) { if (null != delConv) {
sendErrorMsg(sseEmitter, "该对话已经删除"); sendErrorMsg(sseEmitter, "该对话已经删除");
@ -100,7 +106,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
//check 3: conversation quota //check 3: conversation quota
Long convsCount = conversationService.lambdaQuery() Long convsCount = conversationService.lambdaQuery()
.eq(Conversation::getUserId, user.getId()) .eq(Conversation::getUserId, user.getId())
.eq(Conversation::getIsDelete, false) .eq(Conversation::getIsDeleted, false)
.count(); .count();
long convsMax = Integer.parseInt(LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.CONVERSATION_MAX_NUM)); long convsMax = Integer.parseInt(LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.CONVERSATION_MAX_NUM));
if (convsCount >= convsMax) { if (convsCount >= convsMax) {
@ -175,12 +181,16 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
} }
} }
); );
SseAskParams sseAskParams = new SseAskParams();
String prompt = askReq.getPrompt(); String prompt = askReq.getPrompt();
if (StringUtils.isNotBlank(askReq.getRegenerateQuestionUuid())) { if (StringUtils.isNotBlank(askReq.getRegenerateQuestionUuid())) {
prompt = getPromptMsgByQuestionUuid(askReq.getRegenerateQuestionUuid()).getContent(); prompt = getPromptMsgByQuestionUuid(askReq.getRegenerateQuestionUuid()).getRemark();
} }
sseAskParams.setSystemMessage(StringUtils.EMPTY);
sseAskParams.setSseEmitter(sseEmitter);
sseAskParams.setUserMessage(prompt);
sseAskParams.setRegenerateQuestionUuid(askReq.getRegenerateQuestionUuid());
//questions //questions
final List<ChatMessage> chatMessageList = new ArrayList<>();
//system message //system message
Conversation conversation = conversationService.lambdaQuery() Conversation conversation = conversationService.lambdaQuery()
.eq(Conversation::getUuid, askReq.getConversationUuid()) .eq(Conversation::getUuid, askReq.getConversationUuid())
@ -188,8 +198,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
.orElse(null); .orElse(null);
if (null != conversation) { if (null != conversation) {
if (StringUtils.isNotBlank(conversation.getAiSystemMessage())) { if (StringUtils.isNotBlank(conversation.getAiSystemMessage())) {
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), conversation.getAiSystemMessage()); sseAskParams.setSystemMessage(conversation.getAiSystemMessage());
chatMessageList.add(chatMessage);
} }
//history message //history message
if (Boolean.TRUE.equals(conversation.getUnderstandContextEnable()) && user.getUnderstandContextMsgPairNum() > 0) { if (Boolean.TRUE.equals(conversation.getUnderstandContextEnable()) && user.getUnderstandContextMsgPairNum() > 0) {
@ -200,19 +209,23 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
.last("limit " + user.getUnderstandContextMsgPairNum() * 2) .last("limit " + user.getUnderstandContextMsgPairNum() * 2)
.list(); .list();
if (!historyMsgList.isEmpty()) { if (!historyMsgList.isEmpty()) {
ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(1000, new OpenAiTokenizer(GPT_3_5_TURBO));
historyMsgList.sort(Comparator.comparing(ConversationMessage::getId)); historyMsgList.sort(Comparator.comparing(ConversationMessage::getId));
for (ConversationMessage historyMsg : historyMsgList) { for (ConversationMessage historyMsg : historyMsgList) {
ChatMessage chatMessage = new ChatMessage(historyMsg.getMessageRole(), historyMsg.getContent()); if (ChatMessageRole.USER.value().equals(historyMsg.getMessageRole())) {
chatMessageList.add(chatMessage); UserMessage userMessage = UserMessage.from(historyMsg.getRemark());
chatMemory.add(userMessage);
} else if (ChatMessageRole.SYSTEM.value().equals(historyMsg.getMessageRole())) {
SystemMessage userMessage = SystemMessage.from(historyMsg.getRemark());
chatMemory.add(userMessage);
} }
} }
sseAskParams.setChatMemory(chatMemory);
}
} }
} }
//new user message openAiHelper.sseAsk(sseAskParams, (response, questionMeta, answerMeta) -> {
ChatMessage userMessage = new ChatMessage(ChatMessageRole.USER.value(), prompt);
chatMessageList.add(userMessage);
openAiHelper.sseAsk(user, askReq.getRegenerateQuestionUuid(), chatMessageList, sseEmitter, (response, questionMeta, answerMeta) -> {
try { try {
_this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta); _this.saveAfterAiResponse(user, askReq, response, questionMeta, answerMeta);
} catch (Exception e) { } catch (Exception e) {
@ -228,7 +241,7 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
queryWrapper.eq(ConversationMessage::getConversationId, convId); queryWrapper.eq(ConversationMessage::getConversationId, convId);
queryWrapper.eq(ConversationMessage::getParentMessageId, 0); queryWrapper.eq(ConversationMessage::getParentMessageId, 0);
queryWrapper.lt(ConversationMessage::getId, maxId); queryWrapper.lt(ConversationMessage::getId, maxId);
queryWrapper.eq(ConversationMessage::getIsDelete, false); queryWrapper.eq(ConversationMessage::getIsDeleted, false);
queryWrapper.last("limit " + pageSize); queryWrapper.last("limit " + pageSize);
queryWrapper.orderByDesc(ConversationMessage::getId); queryWrapper.orderByDesc(ConversationMessage::getId);
return getBaseMapper().selectList(queryWrapper); return getBaseMapper().selectList(queryWrapper);
@ -257,8 +270,8 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
question.setUuid(questionMeta.getUuid()); question.setUuid(questionMeta.getUuid());
question.setConversationId(conversation.getId()); question.setConversationId(conversation.getId());
question.setConversationUuid(convUuid); question.setConversationUuid(convUuid);
question.setMessageRole(ChatMessageRole.USER.value()); question.setMessageRole(ChatMessageRoleEnum.USER.getValue());
question.setContent(prompt); question.setRemark(prompt);
question.setTokens(questionMeta.getTokens()); question.setTokens(questionMeta.getTokens());
question.setSecretKeyType(secretKeyType); question.setSecretKeyType(secretKeyType);
question.setUnderstandContextMsgPairNum(user.getUnderstandContextMsgPairNum()); question.setUnderstandContextMsgPairNum(user.getUnderstandContextMsgPairNum());
@ -273,8 +286,8 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
aiAnswer.setUuid(answerMeta.getUuid()); aiAnswer.setUuid(answerMeta.getUuid());
aiAnswer.setConversationId(conversation.getId()); aiAnswer.setConversationId(conversation.getId());
aiAnswer.setConversationUuid(convUuid); aiAnswer.setConversationUuid(convUuid);
aiAnswer.setMessageRole(ChatMessageRole.ASSISTANT.value()); aiAnswer.setMessageRole(ChatMessageRoleEnum.ASSISTANT.getValue());
aiAnswer.setContent(response); aiAnswer.setRemark(response);
aiAnswer.setTokens(answerMeta.getTokens()); aiAnswer.setTokens(answerMeta.getTokens());
aiAnswer.setParentMessageId(promptMsg.getId()); aiAnswer.setParentMessageId(promptMsg.getId());
aiAnswer.setSecretKeyType(secretKeyType); aiAnswer.setSecretKeyType(secretKeyType);
@ -321,8 +334,8 @@ public class ConversationMessageService extends ServiceImpl<ConversationMessageM
return this.lambdaUpdate() return this.lambdaUpdate()
.eq(ConversationMessage::getUuid, uuid) .eq(ConversationMessage::getUuid, uuid)
.eq(ConversationMessage::getUserId, ThreadContext.getCurrentUserId()) .eq(ConversationMessage::getUserId, ThreadContext.getCurrentUserId())
.eq(ConversationMessage::getIsDelete, false) .eq(ConversationMessage::getIsDeleted, false)
.set(ConversationMessage::getIsDelete, true) .set(ConversationMessage::getIsDeleted, true)
.update(); .update();
} }

View File

@ -35,7 +35,7 @@ public class ConversationService extends ServiceImpl<ConversationMapper, Convers
public List<ConvDto> listByUser() { public List<ConvDto> listByUser() {
List<Conversation> list = this.lambdaQuery() List<Conversation> list = this.lambdaQuery()
.eq(Conversation::getUserId, ThreadContext.getCurrentUserId()) .eq(Conversation::getUserId, ThreadContext.getCurrentUserId())
.eq(Conversation::getIsDelete, false) .eq(Conversation::getIsDeleted, false)
.orderByDesc(Conversation::getId) .orderByDesc(Conversation::getId)
.last("limit " + sysConfigService.getConversationMaxNum()) .last("limit " + sysConfigService.getConversationMaxNum())
.list(); .list();
@ -61,7 +61,7 @@ public class ConversationService extends ServiceImpl<ConversationMapper, Convers
ConversationMessage maxMsg = conversationMessageService.lambdaQuery() ConversationMessage maxMsg = conversationMessageService.lambdaQuery()
.select(ConversationMessage::getId) .select(ConversationMessage::getId)
.eq(ConversationMessage::getUuid, maxMsgUuid) .eq(ConversationMessage::getUuid, maxMsgUuid)
.eq(ConversationMessage::getIsDelete, false) .eq(ConversationMessage::getIsDeleted, false)
.one(); .one();
if (null == maxMsg) { if (null == maxMsg) {
throw new RuntimeException("找不到对应的消息"); throw new RuntimeException("找不到对应的消息");
@ -88,7 +88,7 @@ public class ConversationService extends ServiceImpl<ConversationMapper, Convers
List<ConversationMessage> childMessages = conversationMessageService List<ConversationMessage> childMessages = conversationMessageService
.lambdaQuery() .lambdaQuery()
.in(ConversationMessage::getParentMessageId, parentIds) .in(ConversationMessage::getParentMessageId, parentIds)
.eq(ConversationMessage::getIsDelete, false) .eq(ConversationMessage::getIsDeleted, false)
.list(); .list();
Map<Long, List<ConversationMessage>> idToMessages = childMessages.stream().collect(Collectors.groupingBy(ConversationMessage::getParentMessageId)); Map<Long, List<ConversationMessage>> idToMessages = childMessages.stream().collect(Collectors.groupingBy(ConversationMessage::getParentMessageId));
@ -126,7 +126,7 @@ public class ConversationService extends ServiceImpl<ConversationMapper, Convers
Conversation conversation = this.lambdaQuery() Conversation conversation = this.lambdaQuery()
.eq(Conversation::getUuid, uuid) .eq(Conversation::getUuid, uuid)
.eq(Conversation::getUserId, ThreadContext.getCurrentUserId()) .eq(Conversation::getUserId, ThreadContext.getCurrentUserId())
.eq(Conversation::getIsDelete, false) .eq(Conversation::getIsDeleted, false)
.one(); .one();
if (null == conversation) { if (null == conversation) {
throw new BaseException(A_CONVERSATION_NOT_EXIST); throw new BaseException(A_CONVERSATION_NOT_EXIST);
@ -144,7 +144,7 @@ public class ConversationService extends ServiceImpl<ConversationMapper, Convers
return this.lambdaUpdate() return this.lambdaUpdate()
.eq(Conversation::getUuid, uuid) .eq(Conversation::getUuid, uuid)
.eq(Conversation::getUserId, ThreadContext.getCurrentUserId()) .eq(Conversation::getUserId, ThreadContext.getCurrentUserId())
.set(Conversation::getIsDelete, true) .set(Conversation::getIsDeleted, true)
.update(); .update();
} }

View File

@ -38,14 +38,14 @@ public class FileService extends ServiceImpl<FileMapper, AdiFile> {
@Value("${local.tmp_images}") @Value("${local.tmp_images}")
private String tmpImagesPath; private String tmpImagesPath;
public String writeToLocal(MultipartFile file) { public AdiFile writeToLocal(MultipartFile file) {
String md5 = MD5Utils.md5ByMultipartFile(file); String md5 = MD5Utils.md5ByMultipartFile(file);
Optional<AdiFile> existFile = this.lambdaQuery() Optional<AdiFile> existFile = this.lambdaQuery()
.eq(AdiFile::getMd5, md5) .eq(AdiFile::getMd5, md5)
.eq(AdiFile::getIsDelete, false) .eq(AdiFile::getIsDeleted, false)
.oneOpt(); .oneOpt();
if (existFile.isPresent()) { if (existFile.isPresent()) {
return existFile.get().getUuid(); return existFile.get();
} }
String uuid = UUID.randomUUID().toString().replace("-", ""); String uuid = UUID.randomUUID().toString().replace("-", "");
Pair<String, String> originalFile = FileUtil.saveToLocal(file, imagePath, uuid); Pair<String, String> originalFile = FileUtil.saveToLocal(file, imagePath, uuid);
@ -56,7 +56,7 @@ public class FileService extends ServiceImpl<FileMapper, AdiFile> {
adiFile.setExt(originalFile.getRight()); adiFile.setExt(originalFile.getRight());
adiFile.setUserId(ThreadContext.getCurrentUserId()); adiFile.setUserId(ThreadContext.getCurrentUserId());
this.getBaseMapper().insert(adiFile); this.getBaseMapper().insert(adiFile);
return uuid; return adiFile;
} }
public String saveToLocal(User user, String sourceImageUrl) { public String saveToLocal(User user, String sourceImageUrl) {
@ -83,7 +83,7 @@ public class FileService extends ServiceImpl<FileMapper, AdiFile> {
return this.lambdaUpdate() return this.lambdaUpdate()
.eq(AdiFile::getUserId, ThreadContext.getCurrentUserId()) .eq(AdiFile::getUserId, ThreadContext.getCurrentUserId())
.eq(AdiFile::getUuid, uuid) .eq(AdiFile::getUuid, uuid)
.set(AdiFile::getIsDelete, true) .set(AdiFile::getIsDeleted, true)
.update(); .update();
} }

View File

@ -0,0 +1,22 @@
package com.moyz.adi.common.service;
import com.moyz.adi.common.helper.EmbeddingHelper;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource;
import org.springframework.stereotype.Service;
@Service
public class Initializer {
@Resource
private SysConfigService sysConfigService;
@Resource
private EmbeddingHelper embeddingHelper;
@PostConstruct
public void init(){
sysConfigService.reload();
embeddingHelper.init();
}
}

View File

@ -0,0 +1,29 @@
package com.moyz.adi.common.service;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.moyz.adi.common.dto.KbItemEmbeddingDto;
import com.moyz.adi.common.entity.KnowledgeBaseEmbedding;
import com.moyz.adi.common.mapper.KnowledgeBaseEmbeddingMapper;
import com.moyz.adi.common.util.MPPageUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@Slf4j
@Service
public class KnowledgeBaseEmbeddingService extends ServiceImpl<KnowledgeBaseEmbeddingMapper, KnowledgeBaseEmbedding> {
public Page<KbItemEmbeddingDto> listByItemUuid(String kbItemUuid, int currentPage, int pageSize) {
Page<KnowledgeBaseEmbedding> sourcePage = baseMapper.selectByItemUuid(new Page<>(currentPage, pageSize), kbItemUuid);
Page<KbItemEmbeddingDto> result = new Page<>();
MPPageUtil.convertTo(sourcePage, result, KbItemEmbeddingDto.class, (source, target) -> {
target.setEmbedding(source.getEmbedding().toArray());
return target;
});
return result;
}
public boolean deleteByItemUuid(String kbItemUuid){
return baseMapper.deleteByItemUuid(kbItemUuid);
}
}

View File

@ -0,0 +1,152 @@
package com.moyz.adi.common.service;
import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.dto.KbItemEditReq;
import com.moyz.adi.common.entity.KnowledgeBase;
import com.moyz.adi.common.entity.KnowledgeBaseItem;
import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.helper.EmbeddingHelper;
import com.moyz.adi.common.mapper.KnowledgeBaseItemMapper;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.Metadata;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.Optional;
import java.util.UUID;
import static com.moyz.adi.common.enums.ErrorEnum.*;
@Slf4j
@Service
public class KnowledgeBaseItemService extends ServiceImpl<KnowledgeBaseItemMapper, KnowledgeBaseItem> {
@Resource
private EmbeddingHelper embeddingHelper;
@Resource
private KnowledgeBaseEmbeddingService knowledgeBaseEmbeddingService;
@Lazy
@Resource
private KnowledgeBaseService knowledgeBaseService;
public KnowledgeBaseItem saveOrUpdate(KbItemEditReq itemEditReq) {
String uuid = itemEditReq.getUuid();
KnowledgeBaseItem item = new KnowledgeBaseItem();
item.setTitle(itemEditReq.getTitle());
if (StringUtils.isNotBlank(itemEditReq.getBrief())) {
item.setBrief(itemEditReq.getBrief());
} else {
item.setBrief(StringUtils.substring(itemEditReq.getRemark(), 0, 200));
}
item.setRemark(itemEditReq.getRemark());
if (null == itemEditReq.getId() || itemEditReq.getId() < 1) {
uuid = UUID.randomUUID().toString().replace("-", "");
item.setUuid(uuid);
item.setKbId(itemEditReq.getKbId());
item.setKbUuid(itemEditReq.getKbUuid());
baseMapper.insert(item);
} else {
item.setId(itemEditReq.getId());
baseMapper.updateById(item);
}
return ChainWrappers.lambdaQueryChain(baseMapper)
.eq(KnowledgeBaseItem::getUuid, uuid)
.one();
}
public KnowledgeBaseItem getEnable(String uuid) {
return ChainWrappers.lambdaQueryChain(baseMapper)
.eq(KnowledgeBaseItem::getUuid, uuid)
.eq(KnowledgeBaseItem::getIsDeleted, false)
.one();
}
public Page<KnowledgeBaseItem> search(String kbUuid, String keyword, Integer currentPage, Integer pageSize) {
LambdaQueryChainWrapper<KnowledgeBaseItem> wrapper = ChainWrappers.lambdaQueryChain(baseMapper);
wrapper.select(KnowledgeBaseItem::getId, KnowledgeBaseItem::getUuid, KnowledgeBaseItem::getTitle, KnowledgeBaseItem::getBrief, KnowledgeBaseItem::getKbUuid, KnowledgeBaseItem::getIsEmbedded, KnowledgeBaseItem::getCreateTime, KnowledgeBaseItem::getUpdateTime);
wrapper.eq(KnowledgeBaseItem::getIsDeleted, false);
wrapper.eq(KnowledgeBaseItem::getKbUuid, kbUuid);
if (StringUtils.isNotBlank(keyword)) {
wrapper.eq(KnowledgeBaseItem::getTitle, keyword);
}
return wrapper.page(new Page<>(currentPage, pageSize));
}
public boolean checkAndEmbedding(String[] uuids) {
if (ArrayUtils.isEmpty(uuids)) {
return false;
}
for (String uuid : uuids) {
checkAndEmbedding(uuid);
}
return true;
}
public boolean checkAndEmbedding(String uuid) {
if (checkPrivilege(uuid)) {
KnowledgeBaseItem one = getEnable(uuid);
return embedding(one);
}
return false;
}
public boolean embedding(KnowledgeBaseItem one) {
Metadata metadata = new Metadata();
metadata.add("kb_uuid", one.getKbUuid());
metadata.add("kb_item_uuid", one.getUuid());
Document document = new Document(one.getRemark(), metadata);
embeddingHelper.getEmbeddingStoreIngestor().ingest(document);
return true;
}
@Transactional
public boolean softDelete(String uuid) {
boolean privilege = checkPrivilege(uuid);
if (!privilege) throw new BaseException(A_USER_NOT_AUTH);
boolean success = ChainWrappers.lambdaUpdateChain(baseMapper)
.eq(KnowledgeBaseItem::getUuid, uuid)
.set(KnowledgeBaseItem::getIsDeleted, true)
.update();
if (!success) {
return false;
}
knowledgeBaseEmbeddingService.deleteByItemUuid(uuid);
return true;
}
private boolean checkPrivilege(String uuid) {
if (StringUtils.isBlank(uuid)) {
throw new BaseException(A_PARAMS_ERROR);
}
User user = ThreadContext.getCurrentUser();
if (null == user) {
throw new BaseException(A_USER_NOT_EXIST);
}
if (user.getIsAdmin()) {
return true;
}
Optional<KnowledgeBaseItem> kbItem = ChainWrappers.lambdaQueryChain(baseMapper)
.eq(KnowledgeBaseItem::getUuid, uuid)
.oneOpt();
if (kbItem.isPresent()) {
KnowledgeBase kb = knowledgeBaseService.getById(kbItem.get().getKbId());
if (null != kb) {
return kb.getOwnerId().equals(user.getId());
}
}
return false;
}
}

View File

@ -0,0 +1,52 @@
package com.moyz.adi.common.service;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.entity.KnowledgeBaseQaRecord;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.mapper.KnowledgeBaseQaRecordMapper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
import static com.moyz.adi.common.enums.ErrorEnum.A_DATA_NOT_FOUND;
@Slf4j
@Service
public class KnowledgeBaseQaRecordService extends ServiceImpl<KnowledgeBaseQaRecordMapper, KnowledgeBaseQaRecord> {
public Page<KnowledgeBaseQaRecord> search(String kbUuid, String keyword, Integer currentPage, Integer pageSize) {
LambdaQueryWrapper<KnowledgeBaseQaRecord> wrapper = new LambdaQueryWrapper<>();
wrapper.eq(KnowledgeBaseQaRecord::getKbUuid, kbUuid);
if (!ThreadContext.getCurrentUser().getIsAdmin()) {
wrapper.eq(KnowledgeBaseQaRecord::getUserId, ThreadContext.getCurrentUserId());
}
if (StringUtils.isNotBlank(keyword)) {
wrapper.like(KnowledgeBaseQaRecord::getQuestion, keyword);
}
wrapper.orderByDesc(KnowledgeBaseQaRecord::getUpdateTime);
return baseMapper.selectPage(new Page<>(currentPage, pageSize), wrapper);
}
public boolean softDelele(String uuid) {
if (ThreadContext.getCurrentUser().getIsAdmin()) {
return ChainWrappers.lambdaUpdateChain(baseMapper)
.eq(KnowledgeBaseQaRecord::getUuid, uuid)
.set(KnowledgeBaseQaRecord::getIsDeleted, true)
.update();
}
KnowledgeBaseQaRecord exist = ChainWrappers.lambdaQueryChain(baseMapper)
.eq(KnowledgeBaseQaRecord::getUuid, uuid)
.one();
if (null == exist) {
throw new BaseException(A_DATA_NOT_FOUND);
}
return ChainWrappers.lambdaUpdateChain(baseMapper)
.eq(KnowledgeBaseQaRecord::getId, exist.getId())
.set(KnowledgeBaseQaRecord::getIsDeleted, true)
.update();
}
}

View File

@ -0,0 +1,249 @@
package com.moyz.adi.common.service;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.cosntant.RedisKeyConstant;
import com.moyz.adi.common.dto.KbEditReq;
import com.moyz.adi.common.entity.*;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.helper.EmbeddingHelper;
import com.moyz.adi.common.mapper.KnowledgeBaseMapper;
import com.moyz.adi.common.util.BizPager;
import com.moyz.adi.common.util.LocalDateTimeUtil;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.parser.TextDocumentParser;
import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentParser;
import dev.langchain4j.data.document.parser.apache.poi.ApachePoiDocumentParser;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import java.text.MessageFormat;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import static com.moyz.adi.common.cosntant.AdiConstant.POI_DOC_TYPES;
import static com.moyz.adi.common.enums.ErrorEnum.*;
import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.loadDocument;
@Slf4j
@Service
public class KnowledgeBaseService extends ServiceImpl<KnowledgeBaseMapper, KnowledgeBase> {
@Resource
private StringRedisTemplate stringRedisTemplate;
@Resource
private EmbeddingHelper embeddingHelper;
@Resource
private KnowledgeBaseItemService knowledgeBaseItemService;
@Resource
private KnowledgeBaseQaRecordService knowledgeBaseQaRecordService;
@Resource
private FileService fileService;
public KnowledgeBase saveOrUpdate(KbEditReq kbEditReq) {
String uuid = kbEditReq.getUuid();
KnowledgeBase knowledgeBase = new KnowledgeBase();
knowledgeBase.setTitle(kbEditReq.getTitle());
knowledgeBase.setRemark(kbEditReq.getRemark());
if (null != kbEditReq.getIsPublic()) {
knowledgeBase.setIsPublic(kbEditReq.getIsPublic());
}
if (null == kbEditReq.getId() || kbEditReq.getId() < 1) {
User user = ThreadContext.getCurrentUser();
uuid = UUID.randomUUID().toString().replace("-", "");
knowledgeBase.setUuid(uuid);
knowledgeBase.setOwnerId(user.getId());
knowledgeBase.setOwnerName(user.getName());
baseMapper.insert(knowledgeBase);
} else {
knowledgeBase.setId(kbEditReq.getId());
baseMapper.updateById(knowledgeBase);
}
return ChainWrappers.lambdaQueryChain(baseMapper)
.eq(KnowledgeBase::getUuid, uuid)
.one();
}
public List<AdiFile> uploadDocs(String kbUuid, Boolean embedding, MultipartFile[] docs) {
if (ArrayUtils.isEmpty(docs)) {
return Collections.emptyList();
}
List<AdiFile> result = new ArrayList<>();
KnowledgeBase knowledgeBase = ChainWrappers.lambdaQueryChain(baseMapper)
.eq(KnowledgeBase::getUuid, kbUuid)
.eq(KnowledgeBase::getIsDeleted, false)
.oneOpt()
.orElseThrow(() -> new BaseException(A_DATA_NOT_FOUND));
for (MultipartFile doc : docs) {
try {
result.add(uploadDoc(knowledgeBase, doc, embedding));
} catch (Exception e) {
log.warn("uploadDocs fail,fileName:{}", doc.getOriginalFilename(), e);
}
}
return result;
}
public AdiFile uploadDoc(String kbUuid, Boolean embedding, MultipartFile doc) {
KnowledgeBase knowledgeBase = ChainWrappers.lambdaQueryChain(baseMapper)
.eq(KnowledgeBase::getUuid, kbUuid)
.eq(KnowledgeBase::getIsDeleted, false)
.oneOpt()
.orElseThrow(() -> new BaseException(A_DATA_NOT_FOUND));
return uploadDoc(knowledgeBase, doc, embedding);
}
private AdiFile uploadDoc(KnowledgeBase knowledgeBase, MultipartFile doc, Boolean embedding) {
try {
String fileName = doc.getOriginalFilename();
AdiFile adiFile = fileService.writeToLocal(doc);
//解析文档
Document document;
if (adiFile.getExt().equalsIgnoreCase("txt")) {
document = loadDocument(adiFile.getPath(), new TextDocumentParser());
} else if (adiFile.getExt().equalsIgnoreCase("pdf")) {
document = loadDocument(adiFile.getPath(), new ApachePdfBoxDocumentParser());
} else if (ArrayUtils.contains(POI_DOC_TYPES, adiFile.getExt())) {
document = loadDocument(adiFile.getPath(), new ApachePoiDocumentParser());
} else {
log.warn("该文件类型:{}无法解析,忽略", adiFile.getExt());
return adiFile;
}
//创建知识库条目
String uuid = UUID.randomUUID().toString().replace("-", "");
KnowledgeBaseItem knowledgeBaseItem = new KnowledgeBaseItem();
knowledgeBaseItem.setUuid(uuid);
knowledgeBaseItem.setKbId(knowledgeBase.getId());
knowledgeBaseItem.setKbUuid(knowledgeBase.getUuid());
knowledgeBaseItem.setSourceFileId(adiFile.getId());
knowledgeBaseItem.setTitle(fileName);
knowledgeBaseItem.setBrief(StringUtils.substring(document.text(), 0, 200));
knowledgeBaseItem.setRemark(document.text());
knowledgeBaseItem.setIsEmbedded(true);
boolean success = knowledgeBaseItemService.save(knowledgeBaseItem);
if (success && Boolean.TRUE.equals(embedding)) {
knowledgeBaseItem = knowledgeBaseItemService.getEnable(uuid);
//向量化
Document docWithoutPath = new Document(document.text());
docWithoutPath.metadata()
.add("kb_uuid", knowledgeBase.getUuid())
.add("kb_item_uuid", knowledgeBaseItem.getUuid());
embeddingHelper.getEmbeddingStoreIngestor().ingest(docWithoutPath);
knowledgeBaseItemService
.lambdaUpdate()
.eq(KnowledgeBaseItem::getId, knowledgeBaseItem.getId())
.set(KnowledgeBaseItem::getIsEmbedded, true)
.update();
}
return adiFile;
} catch (Exception e) {
log.error("upload error", e);
throw new BaseException(A_UPLOAD_FAIL);
}
}
public boolean embedding(String kbUuid, boolean forceAll) {
boolean privilege = checkPrivilege(null, kbUuid);
if (!privilege) throw new BaseException(A_USER_NOT_AUTH);
LambdaQueryWrapper<KnowledgeBaseItem> wrapper = new LambdaQueryWrapper();
wrapper.eq(KnowledgeBaseItem::getIsDeleted, false);
wrapper.eq(KnowledgeBaseItem::getUuid, kbUuid);
BizPager.oneByOneWithAnchor(wrapper, knowledgeBaseItemService, KnowledgeBaseItem::getId, one -> {
if (forceAll || !one.getIsEmbedded()) {
knowledgeBaseItemService.embedding(one);
}
});
return true;
}
public Page<KnowledgeBase> search(String keyword, Boolean includeOthersPublic, Integer currentPage, Integer pageSize) {
User user = ThreadContext.getCurrentUser();
if (user.getIsAdmin()) {
return baseMapper.searchByAdmin(new Page<>(currentPage, pageSize), keyword);
} else {
return baseMapper.searchByUser(new Page<>(currentPage, pageSize), user.getId(), keyword, includeOthersPublic);
}
}
public boolean softDelete(String uuid) {
boolean privs = checkPrivilege(null, uuid);
if (!privs) throw new BaseException(A_USER_NOT_AUTH);
return ChainWrappers.lambdaUpdateChain(baseMapper)
.eq(KnowledgeBase::getUuid, uuid)
.set(KnowledgeBase::getIsDeleted, true)
.update();
}
public KnowledgeBaseQaRecord answerAndRecord(String kbUuid, String question) {
String key = MessageFormat.format(RedisKeyConstant.AQ_ASK_TIMES, ThreadContext.getCurrentUserId(), LocalDateTimeUtil.format(LocalDateTime.now(), "yyyyMMdd"));
String askTimes = stringRedisTemplate.opsForValue().get(key);
String askQuota = SysConfigService.getByKey("quota_by_qa_ask_daily");
if (null != askTimes && null != askQuota && Integer.parseInt(askTimes) >= Integer.parseInt(askQuota)) {
throw new BaseException(A_QA_ASK_LIMIT);
}
stringRedisTemplate.opsForValue().increment(key);
KnowledgeBase knowledgeBase = getOrThrow(kbUuid);
String answer = embeddingHelper.findAnswer(kbUuid, question);
String uuid = UUID.randomUUID().toString().replace("-", "");
KnowledgeBaseQaRecord newObj = new KnowledgeBaseQaRecord();
newObj.setKbId(knowledgeBase.getId());
newObj.setKbUuid((knowledgeBase.getUuid()));
newObj.setUuid(uuid);
newObj.setUserId(ThreadContext.getCurrentUserId());
newObj.setQuestion(question);
newObj.setAnswer(answer);
knowledgeBaseQaRecordService.save(newObj);
return knowledgeBaseQaRecordService.lambdaQuery().eq(KnowledgeBaseQaRecord::getUuid, uuid).one();
}
public KnowledgeBase getOrThrow(String kbUuid) {
return ChainWrappers.lambdaQueryChain(baseMapper)
.eq(KnowledgeBase::getUuid, kbUuid)
.eq(KnowledgeBase::getIsDeleted, false)
.oneOpt().orElseThrow(() -> new BaseException(A_DATA_NOT_FOUND));
}
private boolean checkPrivilege(Long kbId, String kbUuid) {
if (null == kbId && StringUtils.isBlank(kbUuid)) {
throw new BaseException(A_PARAMS_ERROR);
}
User user = ThreadContext.getCurrentUser();
if (null == user) {
throw new BaseException(A_USER_NOT_EXIST);
}
boolean privilege = user.getIsAdmin();
if (privilege) {
return true;
}
LambdaQueryWrapper<KnowledgeBase> wrapper = new LambdaQueryWrapper();
wrapper.eq(KnowledgeBase::getOwnerId, user.getId());
if (null != kbId) {
wrapper = wrapper.eq(KnowledgeBase::getId, kbId);
} else if (StringUtils.isNotBlank(kbUuid)) {
wrapper = wrapper.eq(KnowledgeBase::getUuid, kbUuid);
}
return baseMapper.exists(wrapper);
}
}

View File

@ -25,7 +25,7 @@ import java.util.Map;
public class PromptService extends ServiceImpl<PromptMapper, Prompt> { public class PromptService extends ServiceImpl<PromptMapper, Prompt> {
public List<PromptDto> getAll(long userId) { public List<PromptDto> getAll(long userId) {
List<Prompt> prompts = this.lambdaQuery().eq(Prompt::getUserId, userId).eq(Prompt::getIsDelete, false).list(); List<Prompt> prompts = this.lambdaQuery().eq(Prompt::getUserId, userId).eq(Prompt::getIsDeleted, false).list();
return MPPageUtil.convertTo(prompts, PromptDto.class); return MPPageUtil.convertTo(prompts, PromptDto.class);
} }
@ -34,13 +34,13 @@ public class PromptService extends ServiceImpl<PromptMapper, Prompt> {
if (StringUtils.isNotBlank(keyword)) { if (StringUtils.isNotBlank(keyword)) {
promptPage = this.lambdaQuery() promptPage = this.lambdaQuery()
.eq(Prompt::getUserId, ThreadContext.getCurrentUserId()) .eq(Prompt::getUserId, ThreadContext.getCurrentUserId())
.eq(Prompt::getIsDelete, false) .eq(Prompt::getIsDeleted, false)
.like(Prompt::getAct, keyword) .like(Prompt::getAct, keyword)
.page(new Page<>(currentPage, pageSize)); .page(new Page<>(currentPage, pageSize));
} else { } else {
promptPage = this.lambdaQuery() promptPage = this.lambdaQuery()
.eq(Prompt::getUserId, ThreadContext.getCurrentUserId()) .eq(Prompt::getUserId, ThreadContext.getCurrentUserId())
.eq(Prompt::getIsDelete, false) .eq(Prompt::getIsDeleted, false)
.page(new Page<>(currentPage, pageSize)); .page(new Page<>(currentPage, pageSize));
} }
return MPPageUtil.convertTo(promptPage, new Page<>(), PromptDto.class); return MPPageUtil.convertTo(promptPage, new Page<>(), PromptDto.class);
@ -51,14 +51,14 @@ public class PromptService extends ServiceImpl<PromptMapper, Prompt> {
if (StringUtils.isNotBlank(keyword)) { if (StringUtils.isNotBlank(keyword)) {
promptPage = this.lambdaQuery() promptPage = this.lambdaQuery()
.eq(Prompt::getUserId, ThreadContext.getCurrentUserId()) .eq(Prompt::getUserId, ThreadContext.getCurrentUserId())
.eq(Prompt::getIsDelete, false) .eq(Prompt::getIsDeleted, false)
.like(Prompt::getAct, keyword) .like(Prompt::getAct, keyword)
.last("limit 10") .last("limit 10")
.list(); .list();
} else { } else {
promptPage = this.lambdaQuery() promptPage = this.lambdaQuery()
.eq(Prompt::getUserId, ThreadContext.getCurrentUserId()) .eq(Prompt::getUserId, ThreadContext.getCurrentUserId())
.eq(Prompt::getIsDelete, false) .eq(Prompt::getIsDeleted, false)
.last("limit 10") .last("limit 10")
.list(); .list();
} }
@ -106,7 +106,7 @@ public class PromptService extends ServiceImpl<PromptMapper, Prompt> {
Prompt existOne = this.lambdaQuery() Prompt existOne = this.lambdaQuery()
.eq(Prompt::getUserId, userId) .eq(Prompt::getUserId, userId)
.eq(Prompt::getAct, title) .eq(Prompt::getAct, title)
.eq(Prompt::getIsDelete, false) .eq(Prompt::getIsDeleted, false)
.one(); .one();
if (null != existOne) { if (null != existOne) {
//modify //modify
@ -137,14 +137,14 @@ public class PromptService extends ServiceImpl<PromptMapper, Prompt> {
Prompt prompt = this.lambdaQuery() Prompt prompt = this.lambdaQuery()
.eq(Prompt::getUserId, ThreadContext.getCurrentUserId()) .eq(Prompt::getUserId, ThreadContext.getCurrentUserId())
.eq(Prompt::getId, id) .eq(Prompt::getId, id)
.eq(Prompt::getIsDelete, false) .eq(Prompt::getIsDeleted, false)
.one(); .one();
if (null == prompt) { if (null == prompt) {
return false; return false;
} }
Prompt updateOne = new Prompt(); Prompt updateOne = new Prompt();
updateOne.setId(id); updateOne.setId(id);
updateOne.setIsDelete(true); updateOne.setIsDeleted(true);
return this.updateById(updateOne); return this.updateById(updateOne);
} }
@ -152,7 +152,7 @@ public class PromptService extends ServiceImpl<PromptMapper, Prompt> {
Prompt prompt = this.lambdaQuery() Prompt prompt = this.lambdaQuery()
.eq(Prompt::getId, id) .eq(Prompt::getId, id)
.eq(Prompt::getUserId, ThreadContext.getCurrentUserId()) .eq(Prompt::getUserId, ThreadContext.getCurrentUserId())
.eq(Prompt::getIsDelete, false) .eq(Prompt::getIsDeleted, false)
.one(); .one();
if (null == prompt) { if (null == prompt) {
return false; return false;

View File

@ -1,13 +1,16 @@
package com.moyz.adi.common.service; package com.moyz.adi.common.service;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.moyz.adi.common.cosntant.AdiConstant; import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.model.RequestRateLimit;
import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.entity.SysConfig; import com.moyz.adi.common.entity.SysConfig;
import com.moyz.adi.common.mapper.SysConfigMapper; import com.moyz.adi.common.mapper.SysConfigMapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.vo.RequestRateLimit;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.scheduling.annotation.Scheduled; import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -21,7 +24,7 @@ public class SysConfigService extends ServiceImpl<SysConfigMapper, SysConfig> {
@Scheduled(fixedDelay = 20 * 60 * 1000) @Scheduled(fixedDelay = 20 * 60 * 1000)
public void reload() { public void reload() {
log.info("reload system config"); log.info("reload system config");
List<SysConfig> configsFromDB = this.lambdaQuery().eq(SysConfig::getIsDelete, false).list(); List<SysConfig> configsFromDB = this.lambdaQuery().eq(SysConfig::getIsDeleted, false).list();
if (LocalCache.CONFIGS.isEmpty()) { if (LocalCache.CONFIGS.isEmpty()) {
configsFromDB.stream().forEach(item -> LocalCache.CONFIGS.put(item.getName(), item.getValue())); configsFromDB.stream().forEach(item -> LocalCache.CONFIGS.put(item.getName(), item.getValue()));
} else { } else {
@ -58,4 +61,24 @@ public class SysConfigService extends ServiceImpl<SysConfigMapper, SysConfig> {
return LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.SECRET_KEY); return LocalCache.CONFIGS.get(AdiConstant.SysConfigKey.SECRET_KEY);
} }
public static String getByKey(String key) {
return LocalCache.CONFIGS.get(key);
}
public static Integer getIntByKey(String key) {
String val = LocalCache.CONFIGS.get(key);
if (null != val) {
return Integer.parseInt(val);
}
return null;
}
public Page<SysConfig> search(String keyword, Integer currentPage, Integer pageSize) {
LambdaQueryWrapper<SysConfig> wrapper = new LambdaQueryWrapper<>();
if (StringUtils.isNotBlank(keyword)) {
wrapper.eq(SysConfig::getName, keyword);
}
return baseMapper.selectPage(new Page<>(currentPage, pageSize), wrapper);
}
} }

View File

@ -5,7 +5,7 @@ import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.util.LocalDateTimeUtil; import com.moyz.adi.common.util.LocalDateTimeUtil;
import com.moyz.adi.common.entity.UserDayCost; import com.moyz.adi.common.entity.UserDayCost;
import com.moyz.adi.common.mapper.UserDayCostMapper; import com.moyz.adi.common.mapper.UserDayCostMapper;
import com.moyz.adi.common.model.CostStat; import com.moyz.adi.common.vo.CostStat;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.moyz.adi.common.util.UserUtil; import com.moyz.adi.common.util.UserUtil;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;

View File

@ -3,22 +3,22 @@ package com.moyz.adi.common.service;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers; import com.baomidou.mybatisplus.extension.toolkit.ChainWrappers;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.enums.UserStatusEnum;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.base.ThreadContext; import com.moyz.adi.common.base.ThreadContext;
import com.moyz.adi.common.cosntant.AdiConstant;
import com.moyz.adi.common.cosntant.RedisKeyConstant; import com.moyz.adi.common.cosntant.RedisKeyConstant;
import com.moyz.adi.common.dto.ConfigResp; import com.moyz.adi.common.dto.ConfigResp;
import com.moyz.adi.common.dto.LoginReq; import com.moyz.adi.common.dto.LoginReq;
import com.moyz.adi.common.dto.LoginResp; import com.moyz.adi.common.dto.LoginResp;
import com.moyz.adi.common.dto.UserUpdateReq; import com.moyz.adi.common.dto.UserUpdateReq;
import com.moyz.adi.common.entity.User; import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.enums.ErrorEnum;
import com.moyz.adi.common.enums.UserStatusEnum;
import com.moyz.adi.common.exception.BaseException;
import com.moyz.adi.common.helper.AdiMailSender; import com.moyz.adi.common.helper.AdiMailSender;
import com.moyz.adi.common.mapper.UserMapper; import com.moyz.adi.common.mapper.UserMapper;
import com.moyz.adi.common.model.CostStat; import com.moyz.adi.common.util.JsonUtil;
import com.moyz.adi.common.util.LocalCache;
import com.moyz.adi.common.vo.CostStat;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@ -71,7 +71,7 @@ public class UserService extends ServiceImpl<UserMapper, User> {
} }
return this.lambdaQuery() return this.lambdaQuery()
.eq(User::getEmail, email) .eq(User::getEmail, email)
.eq(User::getIsDelete, false) .eq(User::getIsDeleted, false)
.oneOpt() .oneOpt()
.orElseThrow(() -> new BaseException(A_USER_NOT_EXIST)); .orElseThrow(() -> new BaseException(A_USER_NOT_EXIST));
} }
@ -94,7 +94,7 @@ public class UserService extends ServiceImpl<UserMapper, User> {
stringRedisTemplate.delete(captchaInCache); stringRedisTemplate.delete(captchaInCache);
User user = ChainWrappers.lambdaQueryChain(baseMapper) User user = ChainWrappers.lambdaQueryChain(baseMapper)
.eq(User::getIsDelete, false) .eq(User::getIsDeleted, false)
.eq(User::getEmail, email) .eq(User::getEmail, email)
.one(); .one();
if (null != user && user.getUserStatus() == UserStatusEnum.NORMAL) { if (null != user && user.getUserStatus() == UserStatusEnum.NORMAL) {
@ -112,7 +112,7 @@ public class UserService extends ServiceImpl<UserMapper, User> {
//创建用户 //创建用户
User newOne = new User(); User newOne = new User();
newOne.setName(email.substring(0, email.indexOf("@"))); newOne.setName(StringUtils.substringBetween(email, "@"));
newOne.setUuid(UUID.randomUUID().toString().replace("-", "")); newOne.setUuid(UUID.randomUUID().toString().replace("-", ""));
newOne.setEmail(email); newOne.setEmail(email);
newOne.setPassword(hashed); newOne.setPassword(hashed);
@ -157,7 +157,7 @@ public class UserService extends ServiceImpl<UserMapper, User> {
LambdaQueryWrapper<User> queryWrapper = new LambdaQueryWrapper<>(); LambdaQueryWrapper<User> queryWrapper = new LambdaQueryWrapper<>();
User user = this.lambdaQuery() User user = this.lambdaQuery()
.eq(User::getEmail, email) .eq(User::getEmail, email)
.eq(User::getIsDelete, false) .eq(User::getIsDeleted, false)
.oneOpt() .oneOpt()
.orElse(null); .orElse(null);
if (null == user) { if (null == user) {
@ -201,7 +201,7 @@ public class UserService extends ServiceImpl<UserMapper, User> {
//captcha check end //captcha check end
User user = this.lambdaQuery() User user = this.lambdaQuery()
.eq(User::getIsDelete, false) .eq(User::getIsDeleted, false)
.eq(User::getEmail, loginReq.getEmail()) .eq(User::getEmail, loginReq.getEmail())
.oneOpt() .oneOpt()
.orElseThrow(() -> new BaseException(ErrorEnum.A_USER_NOT_EXIST)); .orElseThrow(() -> new BaseException(ErrorEnum.A_USER_NOT_EXIST));

View File

@ -0,0 +1,341 @@
package com.moyz.adi.common.util;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.pgvector.PGvector;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import lombok.Builder;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.reflect.Type;
import java.sql.*;
import java.util.*;
import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.internal.ValidationUtils.*;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
/**
* 复制并做了少许改动()
* PGVector EmbeddingStore Implementation
* <p>
* Only cosine similarity is used.
* Only ivfflat index is used.
*/
public class AdiPgVectorEmbeddingStore implements EmbeddingStore<TextSegment> {
private static final Logger log = LoggerFactory.getLogger(AdiPgVectorEmbeddingStore.class);
private static final Gson GSON = new Gson();
private final String host;
private final Integer port;
private final String user;
private final String password;
private final String database;
private final String table;
/**
* All args constructor for PgVectorEmbeddingStore Class
*
* @param host The database host
* @param port The database port
* @param user The database user
* @param password The database password
* @param database The database name
* @param table The database table
* @param dimension The vector dimension
* @param useIndex Should use <a href="https://github.com/pgvector/pgvector#ivfflat">IVFFlat</a> index
* @param indexListSize The IVFFlat number of lists
* @param createTable Should create table automatically
* @param dropTableFirst Should drop table first, usually for testing
*/
@Builder
public AdiPgVectorEmbeddingStore(
String host,
Integer port,
String user,
String password,
String database,
String table,
Integer dimension,
Boolean useIndex,
Integer indexListSize,
Boolean createTable,
Boolean dropTableFirst) {
this.host = ensureNotBlank(host, "host");
this.port = ensureGreaterThanZero(port, "port");
this.user = ensureNotBlank(user, "user");
this.password = ensureNotBlank(password, "password");
this.database = ensureNotBlank(database, "database");
this.table = ensureNotBlank(table, "table");
useIndex = getOrDefault(useIndex, false);
createTable = getOrDefault(createTable, true);
dropTableFirst = getOrDefault(dropTableFirst, false);
try (Connection connection = setupConnection()) {
if (dropTableFirst) {
connection.createStatement().executeUpdate(String.format("DROP TABLE IF EXISTS %s", table));
}
if (createTable) {
connection.createStatement().executeUpdate(String.format(
"CREATE TABLE IF NOT EXISTS %s (" +
"embedding_id UUID PRIMARY KEY, " +
"embedding vector(%s), " +
"text TEXT NULL, " +
"metadata JSON NULL" +
")",
table, ensureGreaterThanZero(dimension, "dimension")));
}
if (useIndex) {
final String indexName = table + "_ivfflat_index";
connection.createStatement().executeUpdate(String.format(
"CREATE INDEX IF NOT EXISTS %s ON %s " +
"USING ivfflat (embedding vector_cosine_ops) " +
"WITH (lists = %s)",
indexName, table, ensureGreaterThanZero(indexListSize, "indexListSize")));
}
} catch (SQLException e) {
throw new RuntimeException(e);
}
}
private Connection setupConnection() throws SQLException {
Connection connection = DriverManager.getConnection(
String.format("jdbc:postgresql://%s:%s/%s", host, port, database),
user,
password
);
connection.createStatement().executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
PGvector.addVectorType(connection);
return connection;
}
/**
* Adds a given embedding to the store.
*
* @param embedding The embedding to be added to the store.
* @return The auto-generated ID associated with the added embedding.
*/
@Override
public String add(Embedding embedding) {
String id = randomUUID();
addInternal(id, embedding, null);
return id;
}
/**
* Adds a given embedding to the store.
*
* @param id The unique identifier for the embedding to be added.
* @param embedding The embedding to be added to the store.
*/
@Override
public void add(String id, Embedding embedding) {
addInternal(id, embedding, null);
}
/**
* Adds a given embedding and the corresponding content that has been embedded to the store.
*
* @param embedding The embedding to be added to the store.
* @param textSegment Original content that was embedded.
* @return The auto-generated ID associated with the added embedding.
*/
@Override
public String add(Embedding embedding, TextSegment textSegment) {
String id = randomUUID();
addInternal(id, embedding, textSegment);
return id;
}
/**
* Adds multiple embeddings to the store.
*
* @param embeddings A list of embeddings to be added to the store.
* @return A list of auto-generated IDs associated with the added embeddings.
*/
@Override
public List<String> addAll(List<Embedding> embeddings) {
List<String> ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList());
addAllInternal(ids, embeddings, null);
return ids;
}
/**
* Adds multiple embeddings and their corresponding contents that have been embedded to the store.
*
* @param embeddings A list of embeddings to be added to the store.
* @param embedded A list of original contents that were embedded.
* @return A list of auto-generated IDs associated with the added embeddings.
*/
@Override
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
List<String> ids = embeddings.stream().map(ignored -> randomUUID()).collect(toList());
addAllInternal(ids, embeddings, embedded);
return ids;
}
/**
* Finds the most relevant (closest in space) embeddings to the provided reference embedding.
*
* @param referenceEmbedding The embedding used as a reference. Returned embeddings should be relevant (closest) to this one.
* @param maxResults The maximum number of embeddings to be returned.
* @param minScore The minimum relevance score, ranging from 0 to 1 (inclusive).
* Only embeddings with a score of this value or higher will be returned.
* @return A list of embedding matches.
* Each embedding match includes a relevance score (derivative of cosine distance),
* ranging from 0 (not relevant) to 1 (highly relevant).
*/
@Override
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
List<EmbeddingMatch<TextSegment>> result = new ArrayList<>();
try (Connection connection = setupConnection()) {
String referenceVector = Arrays.toString(referenceEmbedding.vector());
String query = String.format(
"WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s) SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;",
referenceVector, table, minScore, maxResults);
PreparedStatement selectStmt = connection.prepareStatement(query);
ResultSet resultSet = selectStmt.executeQuery();
while (resultSet.next()) {
double score = resultSet.getDouble("score");
String embeddingId = resultSet.getString("embedding_id");
PGvector vector = (PGvector) resultSet.getObject("embedding");
Embedding embedding = new Embedding(vector.toArray());
String text = resultSet.getString("text");
TextSegment textSegment = null;
if (isNotNullOrBlank(text)) {
String metadataJson = Optional.ofNullable(resultSet.getString("metadata")).orElse("{}");
Type type = new TypeToken<Map<String, String>>() {
}.getType();
Metadata metadata = new Metadata(new HashMap<>(GSON.fromJson(metadataJson, type)));
textSegment = TextSegment.from(text, metadata);
}
result.add(new EmbeddingMatch<>(score, embeddingId, embedding, textSegment));
}
} catch (SQLException e) {
throw new RuntimeException(e);
}
return result;
}
private void addInternal(String id, Embedding embedding, TextSegment embedded) {
addAllInternal(
singletonList(id),
singletonList(embedding),
embedded == null ? null : singletonList(embedded));
}
private void addAllInternal(
List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
if (isNullOrEmpty(ids) || isNullOrEmpty(embeddings)) {
log.info("Empty embeddings - no ops");
return;
}
ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
ensureTrue(embedded == null || embeddings.size() == embedded.size(),
"embeddings size is not equal to embedded size");
try (Connection connection = setupConnection()) {
String query = String.format(
"INSERT INTO %s (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?)" +
"ON CONFLICT (embedding_id) DO UPDATE SET " +
"embedding = EXCLUDED.embedding," +
"text = EXCLUDED.text," +
"metadata = EXCLUDED.metadata;",
table);
PreparedStatement upsertStmt = connection.prepareStatement(query);
for (int i = 0; i < ids.size(); ++i) {
upsertStmt.setObject(1, UUID.fromString(ids.get(i)));
upsertStmt.setObject(2, new PGvector(embeddings.get(i).vector()));
if (embedded != null && embedded.get(i) != null) {
upsertStmt.setObject(3, embedded.get(i).text());
Map<String, String> metadata = embedded.get(i).metadata().asMap();
upsertStmt.setObject(4, GSON.toJson(metadata), Types.OTHER);
} else {
upsertStmt.setNull(3, Types.VARCHAR);
upsertStmt.setNull(4, Types.OTHER);
}
upsertStmt.addBatch();
}
upsertStmt.executeBatch();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}
//adi
public List<EmbeddingMatch<TextSegment>> findRelevantByKbUuid(String kbUuid, Embedding referenceEmbedding, int maxResults, double minScore) {
List<EmbeddingMatch<TextSegment>> result = new ArrayList<>();
try (Connection connection = setupConnection()) {
String referenceVector = Arrays.toString(referenceEmbedding.vector());
//新增查询条件kb_id
String query = String.format(
"WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s where metadata->>'kb_uuid' = '%s') SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;",
referenceVector, table, kbUuid, minScore, maxResults);
PreparedStatement selectStmt = connection.prepareStatement(query);
ResultSet resultSet = selectStmt.executeQuery();
while (resultSet.next()) {
double score = resultSet.getDouble("score");
String embeddingId = resultSet.getString("embedding_id");
PGvector vector = (PGvector) resultSet.getObject("embedding");
Embedding embedding = new Embedding(vector.toArray());
String text = resultSet.getString("text");
TextSegment textSegment = null;
if (isNotNullOrBlank(text)) {
String metadataJson = Optional.ofNullable(resultSet.getString("metadata")).orElse("{}");
Type type = new TypeToken<Map<String, String>>() {
}.getType();
Metadata metadata = new Metadata(new HashMap<>(GSON.fromJson(metadataJson, type)));
textSegment = TextSegment.from(text, metadata);
}
result.add(new EmbeddingMatch<>(score, embeddingId, embedding, textSegment));
}
} catch (SQLException e) {
throw new RuntimeException(e);
}
return result;
}
public int deleteByMetadata(String metadataKey, String metadataValue) {
if (StringUtils.isAnyBlank(metadataKey, metadataValue)) {
return NumberUtils.INTEGER_ZERO;
}
try (Connection connection = setupConnection()) {
String query = String.format("delete from %s where metadata->'%s'=?", table, metadataKey);
PreparedStatement prepareStatement = connection.prepareStatement(query);
prepareStatement.setString(1, metadataValue);
return prepareStatement.executeUpdate();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -1,6 +1,6 @@
package com.moyz.adi.common.util; package com.moyz.adi.common.util;
import com.moyz.adi.common.model.RequestRateLimit; import com.moyz.adi.common.vo.RequestRateLimit;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;

View File

@ -9,11 +9,16 @@ import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.BiFunction;
@Slf4j @Slf4j
public class MPPageUtil { public class MPPageUtil {
public static <T, U> Page<U> convertTo(Page<T> source, Page<U> target, Class<U> targetRecordClass) { public static <T, U> Page<U> convertTo(Page<T> source, Page<U> target, Class<U> targetRecordClass) {
return MPPageUtil.convertTo(source, target, targetRecordClass, null);
}
public static <T, U> Page<U> convertTo(Page<T> source, Page<U> target, Class<U> targetRecordClass, BiFunction<T, U, U> biFunction) {
BeanUtils.copyProperties(source, target); BeanUtils.copyProperties(source, target);
List<U> records = new ArrayList<>(); List<U> records = new ArrayList<>();
target.setRecords(records); target.setRecords(records);
@ -21,6 +26,9 @@ public class MPPageUtil {
for (T t : source.getRecords()) { for (T t : source.getRecords()) {
U u = targetRecordClass.getDeclaredConstructor().newInstance(); U u = targetRecordClass.getDeclaredConstructor().newInstance();
BeanUtils.copyProperties(t, u); BeanUtils.copyProperties(t, u);
if (null != biFunction) {
biFunction.apply(t, u);
}
records.add(u); records.add(u);
} }
} catch (NoSuchMethodException e1) { } catch (NoSuchMethodException e1) {

View File

@ -0,0 +1,11 @@
package com.moyz.adi.common.vo;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class AnswerMeta {
private Integer tokens;
private String uuid;
}

View File

@ -0,0 +1,11 @@
package com.moyz.adi.common.vo;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class ChatMeta {
private QuestionMeta question;
private AnswerMeta answer;
}

View File

@ -0,0 +1,14 @@
package com.moyz.adi.common.vo;
import lombok.Data;
@Data
public class CostStat {
private int day;
private int textRequestTimesByDay;
private int textTokenCostByDay;
private int imageGeneratedNumberByDay;
private int textTokenCostByMonth;
private int textRequestTimesByMonth;
private int imageGeneratedNumberByMonth;
}

View File

@ -0,0 +1,11 @@
package com.moyz.adi.common.vo;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class QuestionMeta {
private Integer tokens;
private String uuid;
}

View File

@ -0,0 +1,13 @@
package com.moyz.adi.common.vo;
import lombok.Data;
@Data
public class RequestRateLimit {
private int times;
private int minutes;
private int type;
public static final int TYPE_TEXT = 1;
public static final int TYPE_IMAGE = 2;
}

View File

@ -0,0 +1,25 @@
package com.moyz.adi.common.vo;
import com.moyz.adi.common.entity.User;
import com.moyz.adi.common.util.TriConsumer;
import dev.langchain4j.memory.ChatMemory;
import lombok.Data;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@Data
public class SseAskParams {
private User user;
private String regenerateQuestionUuid;
private String systemMessage;
private ChatMemory chatMemory;
private String userMessage;
private SseEmitter sseEmitter;
}

View File

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.moyz.adi.common.mapper.KnowledgeBaseEmbeddingMapper">
<select id="selectByItemUuid" resultType="com.moyz.adi.common.entity.KnowledgeBaseEmbedding">
select * from adi_knowledge_base_embedding where metadata->>'kb_item_uuid' = #{kbItemUuid}
</select>
<delete id="deleteByItemUuid">
delete from adi_knowledge_base_embedding where metadata->>'kb_item_uuid' = #{kbItemUuid}
</delete>
</mapper>

View File

@ -0,0 +1,32 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.moyz.adi.common.mapper.KnowledgeBaseMapper">
<select id="searchByAdmin" resultType="com.moyz.adi.common.entity.KnowledgeBase">
select *
from adi_knowledge_base
where is_deleted = false
<if test="keyword != null and keyword != ''">
and title like "%"#{keyword}"%"
</if>
order by update_time desc
</select>
<select id="searchByUser" resultType="com.moyz.adi.common.entity.KnowledgeBase">
select *
from adi_knowledge_base
where is_deleted = false
<choose>
<when test="includeOthersPublic">
nd (is_public = true or owner_id = #{ownerId})
</when>
<otherwise>
and owner_id = #{ownerId}
</otherwise>
</choose>
<if test="keyword != null and keyword != ''">
and title like "%"#{keyword}"%"
</if>
order by update_time desc
</select>
</mapper>

View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.moyz.adi.common.mapper.KnowledgeBaseQaRecordMapper">
</mapper>

View File

@ -1,186 +1,528 @@
CREATE TABLE `adi_ai_model` -- 需要先安装pgvector这个扩展https://github.com/pgvector/pgvector
( -- CREATE EXTENSION vector;
`id` bigint NOT NULL AUTO_INCREMENT,
`name` varchar(45) NOT NULL DEFAULT '',
`remark` varchar(1000) DEFAULT NULL,
`model_status` tinyint NOT NULL DEFAULT '1' COMMENT '1:正常使用,2:不可用',
`create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
`update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`id`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci COMMENT ='ai模型';
CREATE TABLE `adi_sys_config` SET client_encoding = 'UTF8';
( CREATE SCHEMA public;
`id` bigint NOT NULL AUTO_INCREMENT,
`name` varchar(100) NOT NULL DEFAULT '',
`value` varchar(100) NOT NULL DEFAULT '',
`create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
`update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`is_delete` tinyint NOT NULL DEFAULT '0',
PRIMARY KEY (`id`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci COMMENT ='系统配置表';
INSERT INTO `adi_sys_config` (`name`, `value`) CREATE TABLE public.adi_ai_image
(
id bigserial primary key,
user_id bigint DEFAULT '0'::bigint NOT NULL,
uuid character varying(32) DEFAULT ''::character varying NOT NULL,
prompt character varying(1024) DEFAULT ''::character varying NOT NULL,
generate_size character varying(20) DEFAULT ''::character varying NOT NULL,
generate_number integer DEFAULT 1 NOT NULL,
original_image character varying(1000) DEFAULT ''::character varying NOT NULL,
mask_image character varying(1000) DEFAULT ''::character varying NOT NULL,
resp_images_path character varying(2048) DEFAULT ''::character varying NOT NULL,
generated_images character varying(2048) DEFAULT ''::character varying NOT NULL,
interacting_method smallint DEFAULT '1'::smallint NOT NULL,
process_status smallint DEFAULT '1'::smallint NOT NULL,
create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
update_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
is_deleted boolean DEFAULT false NOT NULL,
CONSTRAINT adi_ai_image_generate_number_check CHECK (((generate_number >= 1) AND (generate_number <= 10))),
CONSTRAINT adi_ai_image_interacting_method_check CHECK ((interacting_method = ANY (ARRAY [1, 2, 3]))),
CONSTRAINT adi_ai_image_process_status_check CHECK ((process_status = ANY (ARRAY [1, 2, 3]))),
CONSTRAINT adi_ai_image_user_id_check CHECK ((user_id >= 0))
);
ALTER TABLE ONLY public.adi_ai_image
ADD CONSTRAINT udx_uuid UNIQUE (uuid);
COMMENT ON TABLE public.adi_ai_image IS 'Images generated by ai';
COMMENT ON COLUMN public.adi_ai_image.user_id IS 'The user who generated the image';
COMMENT ON COLUMN public.adi_ai_image.uuid IS 'The uuid of the request of generated images';
COMMENT ON COLUMN public.adi_ai_image.prompt IS 'The prompt for generating images';
COMMENT ON COLUMN public.adi_ai_image.generate_size IS 'The size of the generated images. Must be one of "256x256", "512x512", or "1024x1024"';
COMMENT ON COLUMN public.adi_ai_image.generate_number IS 'The number of images to generate. Must be between 1 and 10. Defaults to 1.';
COMMENT ON COLUMN public.adi_ai_image.original_image IS 'The path of the original image (local path or http path), interacting_method must be 2/3';
COMMENT ON COLUMN public.adi_ai_image.mask_image IS 'The path of the mask image (local path or http path), interacting_method must be 2';
COMMENT ON COLUMN public.adi_ai_image.resp_images_path IS 'The url of the generated images which from openai response, separated by commas';
COMMENT ON COLUMN public.adi_ai_image.generated_images IS 'The path of the generated images, separated by commas';
COMMENT ON COLUMN public.adi_ai_image.interacting_method IS '1: Creating images from scratch based on a text prompt; 2: Creating edits of an existing image based on a new text prompt; 3: Creating variations of an existing image';
COMMENT ON COLUMN public.adi_ai_image.process_status IS 'Generate image status, 1: doing, 2: fail, 3: success';
COMMENT ON COLUMN public.adi_ai_image.create_time IS 'Timestamp of record creation';
COMMENT ON COLUMN public.adi_ai_image.update_time IS 'Timestamp of record last update, automatically updated on each update';
COMMENT ON COLUMN public.adi_ai_image.is_deleted IS 'Flag indicating whether the record is deleted (0: not deleted, 1: deleted)';
CREATE TABLE public.adi_ai_model
(
id bigserial primary key,
name character varying(45) DEFAULT ''::character varying NOT NULL,
remark character varying(1000),
model_status smallint DEFAULT '1'::smallint NOT NULL,
create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
update_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
is_deleted boolean DEFAULT false NOT NULL,
CONSTRAINT adi_ai_model_model_status_check CHECK ((model_status = ANY (ARRAY [1, 2])))
);
COMMENT ON TABLE public.adi_ai_model IS 'ai模型';
COMMENT ON COLUMN public.adi_ai_model.name IS 'The name of the AI model';
COMMENT ON COLUMN public.adi_ai_model.remark IS 'Additional remarks about the AI model';
COMMENT ON COLUMN public.adi_ai_model.model_status IS '1: Normal usage, 2: Not available';
COMMENT ON COLUMN public.adi_ai_model.create_time IS 'Timestamp of record creation';
COMMENT ON COLUMN public.adi_ai_model.update_time IS 'Timestamp of record last update, automatically updated on each update';
CREATE TABLE public.adi_conversation
(
id bigserial primary key,
user_id bigint DEFAULT '0'::bigint NOT NULL,
uuid character varying(32) DEFAULT ''::character varying NOT NULL,
title character varying(45) DEFAULT ''::character varying NOT NULL,
openai_conversation_id character varying(32) DEFAULT ''::character varying NOT NULL,
tokens integer DEFAULT 0 NOT NULL,
ai_system_message character varying(1000) DEFAULT ''::character varying NOT NULL,
ai_model character varying(45) DEFAULT ''::character varying NOT NULL,
understand_context_enable boolean DEFAULT false NOT NULL,
create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
update_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
is_deleted boolean DEFAULT false NOT NULL
);
COMMENT ON TABLE public.adi_conversation IS '对话表';
COMMENT ON COLUMN public.adi_conversation.user_id IS '用户id';
COMMENT ON COLUMN public.adi_conversation.ai_model IS '模型名称';
COMMENT ON COLUMN public.adi_conversation.title IS '对话标题';
CREATE TABLE public.adi_conversation_message
(
id bigserial primary key,
parent_message_id bigint DEFAULT '0'::bigint NOT NULL,
conversation_id bigint DEFAULT '0'::bigint NOT NULL,
conversation_uuid character varying(32) DEFAULT ''::character varying NOT NULL,
remark text NOT NULL,
uuid character varying(32) DEFAULT ''::character varying NOT NULL,
message_role integer DEFAULT 1 NOT NULL,
tokens integer DEFAULT 0 NOT NULL,
openai_message_id character varying(32) DEFAULT ''::character varying NOT NULL,
user_id bigint DEFAULT '0'::bigint NOT NULL,
secret_key_type integer DEFAULT 1 NOT NULL,
understand_context_msg_pair_num integer DEFAULT 0 NOT NULL,
create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
update_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
is_deleted boolean DEFAULT false NOT NULL
);
COMMENT ON TABLE public.adi_conversation_message IS '对话消息表';
COMMENT ON COLUMN public.adi_conversation_message.parent_message_id IS '父级消息id';
COMMENT ON COLUMN public.adi_conversation_message.conversation_id IS '对话id';
COMMENT ON COLUMN public.adi_conversation_message.conversation_uuid IS 'conversation''s uuid';
COMMENT ON COLUMN public.adi_conversation_message.remark IS 'ai回复的消息';
COMMENT ON COLUMN public.adi_conversation_message.uuid IS '唯一标识消息的UUID';
COMMENT ON COLUMN public.adi_conversation_message.message_role IS '产生该消息的角色1: 用户, 2: 系统, 3: 助手';
COMMENT ON COLUMN public.adi_conversation_message.tokens IS '消耗的token数量';
COMMENT ON COLUMN public.adi_conversation_message.openai_message_id IS 'OpenAI生成的消息ID';
COMMENT ON COLUMN public.adi_conversation_message.user_id IS '用户ID';
COMMENT ON COLUMN public.adi_conversation_message.secret_key_type IS '加密密钥类型';
COMMENT ON COLUMN public.adi_conversation_message.understand_context_msg_pair_num IS '上下文消息对数量';
CREATE TABLE public.adi_file
(
id bigserial primary key,
name character varying(36) DEFAULT ''::character varying NOT NULL,
uuid character varying(32) DEFAULT ''::character varying NOT NULL,
ext character varying(36) DEFAULT ''::character varying NOT NULL,
user_id bigint DEFAULT '0'::bigint NOT NULL,
path character varying(250) DEFAULT ''::character varying NOT NULL,
ref_count integer DEFAULT 0 NOT NULL,
create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
update_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
is_deleted boolean DEFAULT false NOT NULL,
md5 character varying(128) DEFAULT ''::character varying NOT NULL
);
COMMENT ON TABLE public.adi_file IS '文件';
COMMENT ON COLUMN public.adi_file.name IS 'File name';
COMMENT ON COLUMN public.adi_file.uuid IS 'UUID of the file';
COMMENT ON COLUMN public.adi_file.ext IS 'File extension';
COMMENT ON COLUMN public.adi_file.user_id IS '0: System; Other: User';
COMMENT ON COLUMN public.adi_file.path IS 'File path';
COMMENT ON COLUMN public.adi_file.ref_count IS 'The number of references to this file';
COMMENT ON COLUMN public.adi_file.create_time IS 'Timestamp of record creation';
COMMENT ON COLUMN public.adi_file.update_time IS 'Timestamp of record last update, automatically updated on each update';
COMMENT ON COLUMN public.adi_file.is_deleted IS '0: Normal; 1: Deleted';
COMMENT ON COLUMN public.adi_file.md5 IS 'MD5 hash of the file';
CREATE TABLE public.adi_prompt
(
id bigserial primary key,
user_id bigint DEFAULT '0'::bigint NOT NULL,
act character varying(120) DEFAULT ''::character varying NOT NULL,
prompt text NOT NULL,
create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
update_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
is_deleted boolean DEFAULT false NOT NULL
);
COMMENT ON TABLE public.adi_prompt IS '提示词';
COMMENT ON COLUMN public.adi_prompt.user_id IS '所属用户(0: system)';
COMMENT ON COLUMN public.adi_prompt.act IS '提示词标题';
COMMENT ON COLUMN public.adi_prompt.prompt IS '提示词内容';
COMMENT ON COLUMN public.adi_prompt.create_time IS 'Timestamp of record creation';
COMMENT ON COLUMN public.adi_prompt.update_time IS 'Timestamp of record last update, automatically updated on each update';
COMMENT ON COLUMN public.adi_prompt.is_deleted IS '0:未删除1已删除';
CREATE TABLE public.adi_sys_config
(
id bigserial primary key,
name character varying(100) DEFAULT ''::character varying NOT NULL,
value character varying(100) DEFAULT ''::character varying NOT NULL,
create_time timestamp DEFAULT localtimestamp NOT NULL,
update_time timestamp DEFAULT localtimestamp NOT NULL,
is_deleted boolean DEFAULT false NOT NULL
);
COMMENT ON TABLE public.adi_sys_config IS '系统配置表';
COMMENT ON COLUMN public.adi_sys_config.name IS '配置项名称';
COMMENT ON COLUMN public.adi_sys_config.value IS '配置项值';
COMMENT ON COLUMN public.adi_sys_config.create_time IS 'Timestamp of record creation';
COMMENT ON COLUMN public.adi_sys_config.update_time IS 'Timestamp of record last update, automatically updated on each update';
COMMENT ON COLUMN public.adi_sys_config.is_deleted IS '0未删除1已删除';
CREATE TABLE public.adi_user
(
id bigserial primary key,
name character varying(45) DEFAULT ''::character varying NOT NULL,
password character varying(120) DEFAULT ''::character varying NOT NULL,
uuid character varying(32) DEFAULT ''::character varying NOT NULL,
email character varying(120) DEFAULT ''::character varying NOT NULL,
active_time timestamp,
user_status smallint DEFAULT '1'::smallint NOT NULL,
is_admin boolean DEFAULT false NOT NULL,
quota_by_token_daily integer DEFAULT 0 NOT NULL,
quota_by_token_monthly integer DEFAULT 0 NOT NULL,
quota_by_request_daily integer DEFAULT 0 NOT NULL,
quota_by_request_monthly integer DEFAULT 0 NOT NULL,
secret_key character varying(120) DEFAULT ''::character varying NOT NULL,
understand_context_enable smallint DEFAULT '0'::smallint NOT NULL,
understand_context_msg_pair_num integer DEFAULT 0 NOT NULL,
quota_by_image_daily integer DEFAULT 0 NOT NULL,
quota_by_image_monthly integer DEFAULT 0 NOT NULL,
create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
update_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
is_deleted boolean DEFAULT false NOT NULL
);
COMMENT ON TABLE public.adi_user IS '用户表';
COMMENT ON COLUMN public.adi_user.name IS '用户名';
COMMENT ON COLUMN public.adi_user.password IS '密码';
COMMENT ON COLUMN public.adi_user.uuid IS 'UUID of the user';
COMMENT ON COLUMN public.adi_user.email IS '用户邮箱';
COMMENT ON COLUMN public.adi_user.active_time IS '激活时间';
COMMENT ON COLUMN public.adi_user.create_time IS 'Timestamp of record creation';
COMMENT ON COLUMN public.adi_user.update_time IS 'Timestamp of record last update, automatically updated on each update';
COMMENT ON COLUMN public.adi_user.user_status IS '用户状态1待验证2正常3冻结';
COMMENT ON COLUMN public.adi_user.is_admin IS '是否管理员01';
COMMENT ON COLUMN public.adi_user.is_deleted IS '0未删除1已删除';
COMMENT ON COLUMN public.adi_user.quota_by_token_daily IS '每日token配额';
COMMENT ON COLUMN public.adi_user.quota_by_token_monthly IS '每月token配额';
COMMENT ON COLUMN public.adi_user.quota_by_request_daily IS '每日请求配额';
COMMENT ON COLUMN public.adi_user.quota_by_request_monthly IS '每月请求配额';
COMMENT ON COLUMN public.adi_user.secret_key IS '用户密钥';
COMMENT ON COLUMN public.adi_user.understand_context_enable IS '上下文理解开关';
COMMENT ON COLUMN public.adi_user.understand_context_msg_pair_num IS '上下文消息对数量';
COMMENT ON COLUMN public.adi_user.quota_by_image_daily IS '每日图片配额';
COMMENT ON COLUMN public.adi_user.quota_by_image_monthly IS '每月图片配额';
CREATE TABLE public.adi_user_day_cost
(
id bigserial primary key,
user_id bigint DEFAULT '0'::bigint NOT NULL,
day integer DEFAULT 0 NOT NULL,
requests integer DEFAULT 0 NOT NULL,
tokens integer DEFAULT 0 NOT NULL,
create_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
update_time timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
secret_key_type integer DEFAULT 0 NOT NULL,
images_number integer DEFAULT 0 NOT NULL,
is_deleted boolean DEFAULT false NOT NULL
);
COMMENT ON TABLE public.adi_user_day_cost IS '用户每天消耗总量表';
COMMENT ON COLUMN public.adi_user_day_cost.user_id IS '用户ID';
COMMENT ON COLUMN public.adi_user_day_cost.day IS '日期用7位整数表示如20230901';
COMMENT ON COLUMN public.adi_user_day_cost.requests IS '请求数量';
COMMENT ON COLUMN public.adi_user_day_cost.tokens IS '消耗的token数量';
COMMENT ON COLUMN public.adi_user_day_cost.create_time IS 'Timestamp of record creation';
COMMENT ON COLUMN public.adi_user_day_cost.update_time IS 'Timestamp of record last update, automatically updated on each update';
COMMENT ON COLUMN public.adi_user_day_cost.secret_key_type IS '加密密钥类型';
COMMENT ON COLUMN public.adi_user_day_cost.images_number IS '图片数量';
-- update_time trigger
CREATE OR REPLACE FUNCTION update_modified_column()
RETURNS TRIGGER AS
$$
BEGIN
NEW.update_time = CURRENT_TIMESTAMP;
RETURN NEW;
END;
$$ language 'plpgsql';
CREATE TRIGGER trigger_ai_image_update_time
BEFORE UPDATE
ON adi_ai_image
FOR EACH ROW
EXECUTE PROCEDURE update_modified_column();
CREATE TRIGGER trigger_ai_model_update_time
BEFORE UPDATE
ON adi_ai_model
FOR EACH ROW
EXECUTE PROCEDURE update_modified_column();
CREATE TRIGGER trigger_conv_update_time
BEFORE UPDATE
ON adi_conversation
FOR EACH ROW
EXECUTE PROCEDURE update_modified_column();
CREATE TRIGGER trigger_conv_message_update_time
BEFORE UPDATE
ON adi_conversation_message
FOR EACH ROW
EXECUTE PROCEDURE update_modified_column();
CREATE TRIGGER trigger_file_update_time
BEFORE UPDATE
ON adi_file
FOR EACH ROW
EXECUTE PROCEDURE update_modified_column();
CREATE TRIGGER trigger_prompt_update_time
BEFORE UPDATE
ON adi_prompt
FOR EACH ROW
EXECUTE PROCEDURE update_modified_column();
CREATE TRIGGER trigger_sys_config_update_time
BEFORE UPDATE
ON adi_sys_config
FOR EACH ROW
EXECUTE PROCEDURE update_modified_column();
CREATE TRIGGER trigger_user_update_time
BEFORE UPDATE
ON adi_user
FOR EACH ROW
EXECUTE PROCEDURE update_modified_column();
CREATE TRIGGER trigger_user_day_cost_update_time
BEFORE UPDATE
ON adi_user_day_cost
FOR EACH ROW
EXECUTE PROCEDURE update_modified_column();
INSERT INTO adi_sys_config (name, value)
VALUES ('secret_key', ''); VALUES ('secret_key', '');
INSERT INTO `adi_sys_config` (`name`, `value`) INSERT INTO adi_sys_config (name, value)
VALUES ('request_text_rate_limit', '{"times":24,"minutes":3}'); VALUES ('request_text_rate_limit', '{"times":24,"minutes":3}');
INSERT INTO `adi_sys_config` (`name`, `value`) INSERT INTO adi_sys_config (name, value)
VALUES ('request_image_rate_limit', '{"times":6,"minutes":3}'); VALUES ('request_image_rate_limit', '{"times":6,"minutes":3}');
INSERT INTO `adi_sys_config` (`name`, `value`) INSERT INTO adi_sys_config (name, value)
VALUES ('conversation_max_num', '50'); VALUES ('conversation_max_num', '50');
INSERT INTO `adi_sys_config` (`name`, `value`) INSERT INTO adi_sys_config (name, value)
VALUES ('quota_by_token_daily', '10000'); VALUES ('quota_by_token_daily', '10000');
INSERT INTO `adi_sys_config` (`name`, `value`) INSERT INTO adi_sys_config (name, value)
VALUES ('quota_by_token_monthly', '200000'); VALUES ('quota_by_token_monthly', '200000');
INSERT INTO `adi_sys_config` (`name`, `value`) INSERT INTO adi_sys_config (name, value)
VALUES ('quota_by_request_daily', '150'); VALUES ('quota_by_request_daily', '150');
INSERT INTO `adi_sys_config` (`name`, `value`) INSERT INTO adi_sys_config (name, value)
VALUES ('quota_by_request_monthly', '3000'); VALUES ('quota_by_request_monthly', '3000');
INSERT INTO `adi_sys_config` (`name`, `value`) INSERT INTO adi_sys_config (name, value)
VALUES ('quota_by_image_daily', '30'); VALUES ('quota_by_image_daily', '30');
INSERT INTO `adi_sys_config` (`name`, `value`) INSERT INTO adi_sys_config (name, value)
VALUES ('quota_by_image_monthly', '300'); VALUES ('quota_by_image_monthly', '300');
INSERT INTO adi_sys_config (name, value)
VALUES ('quota_by_qa_ask_daily', '50');
INSERT INTO adi_sys_config (name, value)
VALUES ('quota_by_qa_item_monthly', '100');
CREATE TABLE `adi_conversation` create table adi_knowledge_base
( (
`id` bigint NOT NULL AUTO_INCREMENT, id bigserial primary key,
`user_id` bigint NOT NULL DEFAULT '0' COMMENT '用户id', uuid varchar(32) default ''::character varying not null,
`title` varchar(45) NOT NULL DEFAULT '' COMMENT '对话标题', title varchar(250) default ''::character varying not null,
`uuid` varchar(32) NOT NULL DEFAULT '', remark text default ''::character varying not null,
`understand_context_enable` tinyint NOT NULL default '0' COMMENT '是否开启上下文理解', is_public boolean default false not null,
`ai_model` varchar(45) NOT NULL DEFAULT '' COMMENT 'ai model', owner_id bigint default 0 not null,
`ai_system_message` varchar(1024) NOT NULL DEFAULT '' COMMENT 'set the system message to ai, ig: you are a lawyer', owner_name varchar(45) default ''::character varying not null,
`tokens` int NOT NULL DEFAULT '0' COMMENT '消耗token数量', create_time timestamp default CURRENT_TIMESTAMP not null,
`create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, update_time timestamp default CURRENT_TIMESTAMP not null,
`update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, is_deleted boolean default false not null
`is_delete` tinyint NOT NULL DEFAULT '0', );
PRIMARY KEY (`id`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci COMMENT ='对话表';
CREATE TABLE `adi_conversation_message` comment on table adi_knowledge_base is '知识库';
comment on column adi_knowledge_base.title is '知识库名称';
comment on column adi_knowledge_base.remark is '知识库描述';
comment on column adi_knowledge_base.is_public is '是否公开';
comment on column adi_knowledge_base.owner_id is '所属人id';
comment on column adi_knowledge_base.owner_name is '所属人名称';
comment on column adi_knowledge_base.create_time is '创建时间';
comment on column adi_knowledge_base.update_time is '更新时间';
comment on column adi_knowledge_base.is_deleted is '0未删除1已删除';
create trigger trigger_kb_update_time
before update
on adi_knowledge_base
for each row
execute procedure update_modified_column();
create table adi_knowledge_base_item
( (
`id` bigint NOT NULL AUTO_INCREMENT, id bigserial primary key,
`parent_message_id` bigint NOT NULL DEFAULT '0' COMMENT '父级消息id', uuid varchar(32) default ''::character varying not null,
`conversation_id` bigint NOT NULL DEFAULT '0', kb_id bigint DEFAULT '0'::bigint NOT NULL,
`conversation_uuid` varchar(32) NOT NULL DEFAULT '', kb_uuid varchar(32) default ''::character varying not null,
`user_id` bigint NOT NULL DEFAULT '0' COMMENT 'User id', source_file_id bigint DEFAULT '0'::bigint NOT NULL,
`content` text NOT NULL COMMENT '对话的消息', title varchar(250) default ''::character varying not null,
`uuid` varchar(32) NOT NULL DEFAULT '', brief varchar(250) default ''::character varying not null,
`message_role` varchar(25) NOT NULL DEFAULT '' COMMENT '产生该消息的角色1:user,2:system,3:assistant', remark text default ''::character varying not null,
`tokens` int NOT NULL DEFAULT '0' COMMENT '消耗的token数量', is_embedded boolean default false not null,
`secret_key_type` int NOT NULL DEFAULT '1' COMMENT '1:System secret key,2:User secret key', create_time timestamp default CURRENT_TIMESTAMP not null,
`understand_context_msg_pair_num` int NOT NULL DEFAULT '0' COMMENT 'If context understanding enable, context_pair_msg_num > 0', update_time timestamp default CURRENT_TIMESTAMP not null,
`create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, is_deleted boolean default false not null
`update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, );
`is_delete` tinyint NOT NULL DEFAULT '0' COMMENT '是否删除,0:未删除,1:已删除',
PRIMARY KEY (`id`),
UNIQUE KEY `udx_uuid` (`uuid`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci COMMENT ='对话信息表';
CREATE TABLE `adi_ai_image` comment on table adi_knowledge_base_item is '知识库-条目';
comment on column adi_knowledge_base_item.kb_id is '所属知识库id';
comment on column adi_knowledge_base_item.source_file_id is '来源文件id';
comment on column adi_knowledge_base_item.title is '条目标题';
comment on column adi_knowledge_base_item.brief is '条目内容摘要';
comment on column adi_knowledge_base_item.remark is '条目内容';
comment on column adi_knowledge_base_item.is_embedded is '是否已向量化,0:否,1:是';
comment on column adi_knowledge_base_item.create_time is '创建时间';
comment on column adi_knowledge_base_item.update_time is '更新时间';
comment on column adi_knowledge_base_item.is_deleted is '0未删除1已删除';
create trigger trigger_kb_item_update_time
before update
on adi_knowledge_base_item
for each row
execute procedure update_modified_column();
create table adi_knowledge_base_qa_record
( (
`id` bigint NOT NULL AUTO_INCREMENT, id bigserial primary key,
`user_id` bigint NOT NULL DEFAULT '0' COMMENT 'The user who generated the image', uuid varchar(32) default ''::character varying not null,
`uuid` varchar(32) NOT NULL DEFAULT '' COMMENT 'The uuid of the request of generated images', kb_id bigint DEFAULT '0'::bigint NOT NULL,
`prompt` varchar(1024) NOT NULL DEFAULT '' COMMENT 'The prompt for generating images', kb_uuid varchar(32) default ''::character varying not null,
`generate_size` varchar(20) NOT NULL DEFAULT '' COMMENT 'The size of the generated images. Must be one of "256x256", "512x512", or "1024x1024"', question varchar(1000) default ''::character varying not null,
`generate_number` int NOT NULL DEFAULT '1' COMMENT 'The number of images to generate. Must be between 1 and 10. Defaults to 1.', answer text default ''::character varying not null,
`original_image` varchar(32) NOT NULL DEFAULT '' COMMENT 'The original image uuid,interacting_method must be 2/3', source_file_ids varchar(500) default ''::character varying not null,
`mask_image` varchar(32) NOT NULL DEFAULT '' COMMENT 'The mask image uuid,interacting_method must be 2', user_id bigint default '0' NOT NULL,
`resp_images_path` varchar(2048) NOT NULL DEFAULT '' COMMENT 'The url of the generated images which from openai response,separated by commas', create_time timestamp default CURRENT_TIMESTAMP not null,
`generated_images` varchar(512) NOT NULL DEFAULT '' COMMENT 'The uuid of the generated images,separated by commas', update_time timestamp default CURRENT_TIMESTAMP not null,
`interacting_method` smallint NOT NULL DEFAULT '1' COMMENT '1:Creating images from scratch based on a text prompt;2:Creating edits of an existing image based on a new text prompt;3:Creating variations of an existing image', is_deleted boolean default false not null
`process_status` smallint NOT NULL DEFAULT '1' COMMENT 'Process status,1:processing,2:fail,3:success', );
`create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
`update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`is_delete` tinyint NOT NULL DEFAULT '0',
PRIMARY KEY (`id`),
UNIQUE KEY `udx_uuid` (`uuid`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci COMMENT ='Images generated by ai';
CREATE TABLE `adi_user` comment on table adi_knowledge_base_qa_record is '知识库-提问记录';
(
`id` bigint NOT NULL AUTO_INCREMENT,
`name` varchar(45) NOT NULL DEFAULT '',
`password` varchar(120) NOT NULL DEFAULT '',
`uuid` varchar(32) NOT NULL DEFAULT '',
`email` varchar(120) NOT NULL DEFAULT '',
`active_time` datetime NULL COMMENT '激活时间',
`secret_key` varchar(120) NOT NULL default '' COMMENT 'Custom openai secret key',
`understand_context_msg_pair_num` int NOT NULL default '0' COMMENT '上下文理解中需要携带的消息对数量(提示词及回复)',
`quota_by_token_daily` int NOT NULL DEFAULT '0' COMMENT 'The quota of token daily',
`quota_by_token_monthly` int NOT NULL DEFAULT '0' COMMENT 'The quota of token monthly',
`quota_by_request_daily` int NOT NULL DEFAULT '0' COMMENT 'The quota of http request daily',
`quota_by_request_monthly` int NOT NULL DEFAULT '0' COMMENT 'The quota of http request monthly',
`quota_by_image_daily` int NOT NULL DEFAULT '0' COMMENT 'The quota of generate images daily',
`quota_by_image_monthly` int NOT NULL DEFAULT '0' COMMENT 'The quota of generate images monthly',
`create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
`update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`user_status` smallint NOT NULL DEFAULT '1' COMMENT '用户状态1待验证2正常3冻结',
`is_delete` tinyint NOT NULL DEFAULT '0' COMMENT '0未删除1已删除',
PRIMARY KEY (`id`),
UNIQUE KEY `udx_uuid` (`uuid`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci COMMENT ='用户表';
CREATE TABLE `adi_user_day_cost` comment on column adi_knowledge_base_qa_record.kb_id is '所属知识库id';
(
`id` bigint NOT NULL AUTO_INCREMENT,
`user_id` bigint NOT NULL DEFAULT '0',
`day` int NOT NULL DEFAULT '0' COMMENT '日期,用7位整数表示,如20230901',
`requests` int NOT NULL DEFAULT '0' COMMENT 'The number of http request',
`tokens` int NOT NULL DEFAULT '0' COMMENT 'The cost of the tokens',
`images_number` int NOT NULL DEFAULT '0' COMMENT 'The number of images',
`secret_key_type` int NOT NULL DEFAULT '1' COMMENT '1:System secret key,2:Custom secret key',
`create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
`update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`id`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci COMMENT ='用户每天消耗总量表';
comment on column adi_knowledge_base_qa_record.kb_uuid is '所属知识库uuid';
CREATE TABLE `adi_prompt` comment on column adi_knowledge_base_qa_record.question is '问题';
(
`id` bigint NOT NULL AUTO_INCREMENT,
`user_id` bigint NOT NULL DEFAULT '0' COMMENT '0:System,other:User',
`act` varchar(120) NOT NULL DEFAULT '' COMMENT 'Short prompt for search/autocomplete',
`prompt` varchar(1024) NOT NULL COMMENT 'Prompt content',
`create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
`update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`is_delete` tinyint NOT NULL DEFAULT '0' COMMENT '0:Normal;1:Deleted',
PRIMARY KEY (`id`),
KEY `idx_title` (`act`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci COMMENT ='提示词';
CREATE TABLE `adi_file` comment on column adi_knowledge_base_qa_record.answer is '答案';
(
`id` bigint NOT NULL AUTO_INCREMENT, comment on column adi_knowledge_base_qa_record.source_file_ids is '来源文档id,以逗号隔开';
`name` varchar(32) NOT NULL DEFAULT '',
`uuid` varchar(32) NOT NULL DEFAULT '', comment on column adi_knowledge_base_qa_record.user_id is '提问用户id';
`md5` varchar(128) NOT NULL DEFAULT '',
`ext` varchar(32) NOT NULL DEFAULT '', comment on column adi_knowledge_base_qa_record.create_time is '创建时间';
`user_id` bigint NOT NULL DEFAULT '0' COMMENT '0:System,other:User',
`path` varchar(250) NOT NULL DEFAULT '', comment on column adi_knowledge_base_qa_record.update_time is '更新时间';
`ref_count` int NOT NULL DEFAULT '0' COMMENT 'The number of references to this file',
`create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, comment on column adi_knowledge_base_qa_record.is_deleted is '0未删除1已删除';
`update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`is_delete` tinyint NOT NULL DEFAULT '0' COMMENT '0:Normal;1:Deleted', create trigger trigger_kb_qa_record_update_time
PRIMARY KEY (`id`), before update
KEY `idx_uuid` (`uuid`) on adi_knowledge_base_qa_record
) ENGINE = InnoDB for each row
DEFAULT CHARSET = utf8mb4 execute procedure update_modified_column();
COLLATE = utf8mb4_general_ci COMMENT ='提示词';

36
pom.xml
View File

@ -25,6 +25,7 @@
<maven.compiler.source>17</maven.compiler.source> <maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target> <maven.compiler.target>17</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<langchain4j.version>0.25.0</langchain4j.version>
</properties> </properties>
<dependencies> <dependencies>
<dependency> <dependency>
@ -68,9 +69,8 @@
<version>31.1-jre</version> <version>31.1-jre</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.mysql</groupId> <groupId>org.postgresql</groupId>
<artifactId>mysql-connector-j</artifactId> <artifactId>postgresql</artifactId>
<version>8.0.32</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.theokanning.openai-gpt3-java</groupId> <groupId>com.theokanning.openai-gpt3-java</groupId>
@ -130,6 +130,36 @@
<artifactId>velocity-engine-core</artifactId> <artifactId>velocity-engine-core</artifactId>
<version>2.3</version> <version>2.3</version>
</dependency> </dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai-spring-boot-starter</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-pgvector</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-document-parser-apache-pdfbox</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-document-parser-apache-poi</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.springframework.boot</groupId> <groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId> <artifactId>spring-boot-starter-test</artifactId>