ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/KnowledgeInfo.java
@@ -1,5 +1,6 @@ package org.ruoyi.domain; import com.alibaba.excel.annotation.ExcelProperty; import com.baomidou.mybatisplus.annotation.*; import lombok.Data; import lombok.EqualsAndHashCode; @@ -78,14 +79,19 @@ private Long textBlockSize; /** * 向量库 * 向量库模型名称 */ private String vector; private String vectorModelName; /** * 向量模型 * 向量化模型名称 */ private String vectorModel; private String embeddingModelName; /** * 系统提示词 */ private String systemPrompt; /** * 备注 ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/KnowledgeInfoBo.java
@@ -83,16 +83,22 @@ private Long textBlockSize; /** * 向量库 * 向量库模型名称 */ @NotBlank(message = "向量库不能为空", groups = { AddGroup.class, EditGroup.class }) private String vector; private String vectorModelName; /** * 向量模型 * 向量化模型名称 */ @NotBlank(message = "向量模型不能为空", groups = { AddGroup.class, EditGroup.class }) private String vectorModel; private String embeddingModelName; /** * 系统提示词 */ private String systemPrompt; /** * 备注 ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/QueryVectorBo.java
@@ -26,9 +26,14 @@ private Integer maxResults; /** * 模型名称 * 向量库模型名称 */ private String modelName; private String vectorModelName; /** * 向量化模型名称 */ private String embeddingModelName; /** * 请求key ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/StoreEmbeddingBo.java
@@ -32,9 +32,14 @@ private List<String> fids; /** * 模型名称 * 向量库模型名称 */ private String modelName; private String vectorModelName; /** * 向量化模型名称 */ private String embeddingModelName; /** * 请求key ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/vo/KnowledgeInfoVo.java
@@ -98,16 +98,20 @@ private Integer textBlockSize; /** * 向量库 * 向量库模型名称 */ @ExcelProperty(value = "向量库") private String vector; private String vectorModelName; /** * 向量模型 * 向量化模型名称 */ @ExcelProperty(value = "向量模型") private String vectorModel; private String embeddingModelName; /** * 系统提示词 */ private String systemPrompt; /** * 备注 ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
@@ -13,14 +13,14 @@ void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo); void removeByDocId(String kid,String docId); void removeByKid(String kid); List<String> getQueryVector(QueryVectorBo queryVectorBo); void createSchema(String kid,String modelName); void removeByKidAndFid(String kid, String fid); void removeByKid(String kid,String modelName); void removeByDocId(String kid,String docId,String modelName); void removeByKidAndFid(String kid, String fid,String modelName); } ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java
@@ -1,5 +1,7 @@ package org.ruoyi.service.impl; import cn.hutool.core.util.RandomUtil; import com.google.protobuf.ServiceException; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; @@ -16,6 +18,7 @@ import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore; import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore; import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.ruoyi.common.core.service.ConfigService; import org.ruoyi.domain.bo.QueryVectorBo; @@ -40,11 +43,10 @@ private final ConfigService configService; Map<String,EmbeddingStore<TextSegment>> storeMap = new HashMap<>(); private EmbeddingStore<TextSegment> embeddingStore; @Override public void createSchema(String kid,String modelName) { EmbeddingStore<TextSegment> embeddingStore; switch (modelName) { case "weaviate" -> { String protocol = configService.getConfigValue("weaviate", "protocol"); @@ -84,88 +86,83 @@ embeddingStore = new InMemoryEmbeddingStore<>(); } } storeMap.put(kid,embeddingStore); } @Override public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) { EmbeddingStore<TextSegment> store = storeMap.get(storeEmbeddingBo.getKid()); EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getModelName(), createSchema(storeEmbeddingBo.getKid(),storeEmbeddingBo.getVectorModelName()); EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl()); for (int i = 0; i < storeEmbeddingBo.getChunkList().size(); i++) { List<String> chunkList = storeEmbeddingBo.getChunkList(); for (int i = 0; i < chunkList.size(); i++) { Map<String, Object> dataSchema = new HashMap<>(); dataSchema.put("kid", storeEmbeddingBo.getKid()); dataSchema.put("docId", storeEmbeddingBo.getKid()); dataSchema.put("fid", storeEmbeddingBo.getFids().get(i)); Response<Embedding> response = embeddingModel.embed(storeEmbeddingBo.getChunkList().get(i)); Embedding embedding = response.content(); TextSegment segment = TextSegment.from(storeEmbeddingBo.getChunkList().get(i)); Embedding embedding = embeddingModel.embed(chunkList.get(i)).content(); TextSegment segment = TextSegment.from(chunkList.get(i)); segment.metadata().putAll(dataSchema); store.add(embedding,segment); embeddingStore.add(embedding,segment); } } @Override public List<String> getQueryVector(QueryVectorBo queryVectorBo) { EmbeddingStore<TextSegment> store = storeMap.get(queryVectorBo.getKid()); EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getModelName(), createSchema(queryVectorBo.getKid(),queryVectorBo.getVectorModelName()); EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl()); Filter simpleFilter = new IsEqualTo("kid", queryVectorBo.getKid()); // Filter simpleFilter = new IsEqualTo("kid", queryVectorBo.getKid()); Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder() .queryEmbedding(queryEmbedding) .maxResults(queryVectorBo.getMaxResults()) // 添加过滤条件 .filter(simpleFilter) // .filter(simpleFilter) .build(); List<EmbeddingMatch<TextSegment>> matches = store.search(embeddingSearchRequest).matches(); List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(embeddingSearchRequest).matches(); List<String> results = new ArrayList<>(); matches.forEach(embeddingMatch -> results.add(embeddingMatch.embedded().text())); return results; } @Override public void removeByKid(String kid) { EmbeddingStore<TextSegment> store = storeMap.get(kid); public void removeByKid(String kid,String modelName) { createSchema(kid,modelName); // 根据条件删除向量数据 Filter simpleFilter = new IsEqualTo("kid", kid); store.removeAll(simpleFilter); embeddingStore.removeAll(simpleFilter); } @Override public void removeByDocId(String kid, String docId) { EmbeddingStore<TextSegment> store = storeMap.get(kid); public void removeByDocId(String kid, String docId,String modelName) { createSchema(kid,modelName); // 根据条件删除向量数据 Filter simpleFilterByDocId = new IsEqualTo("docId", docId); store.removeAll(simpleFilterByDocId); embeddingStore.removeAll(simpleFilterByDocId); } @Override public void removeByKidAndFid(String kid, String fid) { EmbeddingStore<TextSegment> store = storeMap.get(kid); public void removeByKidAndFid(String kid, String fid,String modelName) { createSchema(kid,modelName); // 根据条件删除向量数据 Filter simpleFilterByKid = new IsEqualTo("kid", kid); Filter simpleFilterFid = new IsEqualTo("fid", fid); Filter simpleFilterByAnd = Filter.and(simpleFilterFid, simpleFilterByKid); store.removeAll(simpleFilterByAnd); embeddingStore.removeAll(simpleFilterByAnd); } /** * 获取向量模型 */ public EmbeddingModel getEmbeddingModel(String modelName,String apiKey,String baseUrl) { EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder().build(); @SneakyThrows public EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) { EmbeddingModel embeddingModel; if(TEXT_EMBEDDING_3_SMALL.toString().equals(modelName)) { embeddingModel = OpenAiEmbeddingModel.builder() .apiKey(apiKey) .baseUrl(baseUrl) .modelName(TEXT_EMBEDDING_3_SMALL) .modelName(modelName) .build(); // TODO 添加枚举 }else if("quentinz/bge-large-zh-v1.5".equals(modelName)) { @@ -173,6 +170,14 @@ .baseUrl(baseUrl) .modelName(modelName) .build(); }else if("baai/bge-m3".equals(modelName)) { embeddingModel = OpenAiEmbeddingModel.builder() .apiKey(apiKey) .baseUrl(baseUrl) .modelName(modelName) .build(); }else { throw new ServiceException("未找到对应向量化模型!"); } return embeddingModel; } ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java
@@ -2,6 +2,7 @@ import cn.dev33.satoken.stp.StpUtil; import cn.hutool.core.collection.CollectionUtil; import com.baomidou.mybatisplus.core.toolkit.Wrappers; import com.google.protobuf.ServiceException; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; @@ -29,6 +30,8 @@ import org.ruoyi.domain.bo.ChatSessionBo; import org.ruoyi.domain.bo.QueryVectorBo; import org.ruoyi.domain.vo.ChatModelVo; import org.ruoyi.domain.vo.KnowledgeInfoVo; import org.ruoyi.service.IKnowledgeInfoService; import org.ruoyi.service.VectorStoreService; import org.ruoyi.service.IChatModelService; import org.ruoyi.service.IChatSessionService; @@ -66,6 +69,8 @@ private final ChatServiceFactory chatServiceFactory; private final IChatSessionService chatSessionService; private final IKnowledgeInfoService knowledgeInfoService; private ChatModelVo chatModelVo; @@ -148,50 +153,61 @@ } } /** * 构建消息列表 */ private void buildChatMessageList(ChatRequest chatRequest){ chatModelVo = chatModelService.selectModelByName(chatRequest.getModel()); String sysPrompt; chatModelVo = chatModelService.selectModelByName(chatRequest.getModel()); // 获取对话消息列表 List<Message> messages = chatRequest.getMessages(); String sysPrompt = chatModelVo.getSystemPrompt(); // 查询向量库相关信息加入到上下文 if(StringUtils.isNotEmpty(chatRequest.getKid())){ List<Message> knMessages = new ArrayList<>(); String content = messages.get(messages.size() - 1).getContent().toString(); // 通过kid查询知识库信息 KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKid())); // 查询向量模型配置信息 ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName()); if(StringUtils.isEmpty(sysPrompt)){ // TODO 系统默认提示词,后续会增加提示词管理 sysPrompt ="你是一个由RuoYI-AI开发的人工智能助手,名字叫熊猫助手。你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。" + "当前时间:"+ DateUtils.getDate()+ "#注意:回复之前注意结合上下文和工具返回内容进行回复。"; QueryVectorBo queryVectorBo = new QueryVectorBo(); queryVectorBo.setQuery(content); queryVectorBo.setKid(chatRequest.getKid()); queryVectorBo.setApiKey(chatModel.getApiKey()); queryVectorBo.setBaseUrl(chatModel.getApiHost()); queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModelName()); queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName()); queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit()); List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo); for (String prompt : nearestList) { Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build(); knMessages.add(userMessage); } messages.addAll(knMessages); // 设置知识库系统提示词 sysPrompt = knowledgeInfoVo.getSystemPrompt(); if(StringUtils.isEmpty(sysPrompt)){ sysPrompt ="###角色设定\n" + "你是一个智能知识助手,专注于利用上下文中的信息来提供准确和相关的回答。\n" + "###指令\n" + "当用户的问题与上下文知识匹配时,利用上下文信息进行回答。如果问题与上下文不匹配,运用自身的推理能力生成合适的回答。\n" + "###限制\n" + "确保回答清晰简洁,避免提供不必要的细节。始终保持语气友好" + "当前时间:"+ DateUtils.getDate(); } }else { sysPrompt = chatModelVo.getSystemPrompt(); if(StringUtils.isEmpty(sysPrompt)){ sysPrompt ="你是一个由RuoYI-AI开发的人工智能助手,名字叫熊猫助手。你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。" + "当前时间:"+ DateUtils.getDate()+ "#注意:回复之前注意结合上下文和工具返回内容进行回复。"; } } // 设置系统默认提示词 Message sysMessage = Message.builder().content(sysPrompt).role(Message.Role.SYSTEM).build(); messages.add(0,sysMessage); chatRequest.setSysPrompt(sysPrompt); // 查询向量库相关信息加入到上下文 if(StringUtils.isNotEmpty(chatRequest.getKid())){ List<Message> knMessages = new ArrayList<>(); String content = messages.get(messages.size() - 1).getContent().toString(); QueryVectorBo queryVectorBo = new QueryVectorBo(); queryVectorBo.setQuery(content); queryVectorBo.setKid(chatRequest.getKid()); queryVectorBo.setApiKey(chatModelVo.getApiKey()); queryVectorBo.setBaseUrl(chatModelVo.getApiHost()); queryVectorBo.setModelName(chatModelVo.getModelName()); // TODO 查询向量返回条数,这里应该查询知识库配置 queryVectorBo.setMaxResults(3); List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo); for (String prompt : nearestList) { Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build(); knMessages.add(userMessage); } // TODO 提示词,这里应该查询知识库配置 Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build(); knMessages.add(userMessage); messages.addAll(knMessages); } // 用户对话内容 String chatString = null; // 获取用户对话信息 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java
@@ -102,8 +102,6 @@ lqw.eq(bo.getOverlapChar() != null, KnowledgeInfo::getOverlapChar, bo.getOverlapChar()); lqw.eq(bo.getRetrieveLimit() != null, KnowledgeInfo::getRetrieveLimit, bo.getRetrieveLimit()); lqw.eq(bo.getTextBlockSize() != null, KnowledgeInfo::getTextBlockSize, bo.getTextBlockSize()); lqw.eq(StringUtils.isNotBlank(bo.getVector()), KnowledgeInfo::getVector, bo.getVector()); lqw.eq(StringUtils.isNotBlank(bo.getVectorModel()), KnowledgeInfo::getVectorModel, bo.getVectorModel()); return lqw; } @@ -161,7 +159,7 @@ } baseMapper.insert(knowledgeInfo); if (knowledgeInfo != null) { vectorStoreService.createSchema(String.valueOf(knowledgeInfo.getId()),bo.getVector()); vectorStoreService.createSchema(String.valueOf(knowledgeInfo.getId()),bo.getVectorModelName()); } }else { baseMapper.updateById(knowledgeInfo); @@ -177,7 +175,7 @@ check(knowledgeInfoList); // 删除向量库信息 knowledgeInfoList.forEach(knowledgeInfoVo -> { vectorStoreService.removeByKid(String.valueOf(knowledgeInfoVo.getId())); vectorStoreService.removeByKid(String.valueOf(knowledgeInfoVo.getId()),knowledgeInfoVo.getVectorModelName()); }); // 删除附件和知识片段 fragmentMapper.deleteByMap(map); @@ -231,17 +229,18 @@ // 通过kid查询知识库信息 KnowledgeInfoVo knowledgeInfoVo = baseMapper.selectVoOne(Wrappers.<KnowledgeInfo>lambdaQuery() .eq(KnowledgeInfo::getKid, kid)); .eq(KnowledgeInfo::getId, kid)); // 通过向量模型查询模型信息 ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getVectorModel()); ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName()); StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo(); storeEmbeddingBo.setKid(kid); storeEmbeddingBo.setDocId(docId); storeEmbeddingBo.setFids(fids); storeEmbeddingBo.setChunkList(chunkList); storeEmbeddingBo.setModelName(knowledgeInfoVo.getVectorModel()); storeEmbeddingBo.setVectorModelName(knowledgeInfoVo.getVectorModelName()); storeEmbeddingBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName()); storeEmbeddingBo.setApiKey(chatModelVo.getApiKey()); storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost()); vectorStoreService.storeEmbeddings(storeEmbeddingBo);