| | |
| | | package org.ruoyi.controller; |
| | | |
| | | import cn.dev33.satoken.stp.StpUtil; |
| | | import jakarta.servlet.http.HttpServletRequest; |
| | | import jakarta.servlet.http.HttpServletResponse; |
| | | import jakarta.validation.Valid; |
| | | import jakarta.validation.constraints.NotEmpty; |
| | | import jakarta.validation.constraints.NotNull; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.ruoyi.common.chat.config.ChatConfig; |
| | | import org.ruoyi.common.chat.domain.request.ChatRequest; |
| | | import org.ruoyi.common.chat.entity.chat.ChatCompletion; |
| | | import org.ruoyi.common.chat.entity.chat.Message; |
| | | import org.ruoyi.common.chat.openai.OpenAiStreamClient; |
| | | import org.ruoyi.common.core.domain.R; |
| | | import org.ruoyi.common.core.validate.AddGroup; |
| | | import org.ruoyi.common.excel.utils.ExcelUtil; |
| | |
| | | import org.ruoyi.common.mybatis.core.page.TableDataInfo; |
| | | import org.ruoyi.common.satoken.utils.LoginHelper; |
| | | import org.ruoyi.common.web.core.BaseController; |
| | | import org.ruoyi.knowledge.chain.vectorstore.VectorStore; |
| | | import org.ruoyi.knowledge.domain.bo.KnowledgeAttachBo; |
| | | import org.ruoyi.knowledge.domain.bo.KnowledgeFragmentBo; |
| | | import org.ruoyi.knowledge.domain.bo.KnowledgeInfoBo; |
| | |
| | | import org.ruoyi.knowledge.service.IKnowledgeAttachService; |
| | | import org.ruoyi.knowledge.service.IKnowledgeFragmentService; |
| | | import org.ruoyi.knowledge.service.IKnowledgeInfoService; |
| | | import org.ruoyi.system.listener.SSEEventSourceListener; |
| | | import org.ruoyi.system.service.ISseService; |
| | | import org.springframework.validation.annotation.Validated; |
| | | import org.springframework.web.bind.annotation.*; |
| | | import org.ruoyi.knowledge.chain.vectorstore.VectorStore; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | |
| | | import java.util.List; |
| | |
| | | |
| | | private final EmbeddingService embeddingService; |
| | | |
| | | private OpenAiStreamClient openAiStreamClient; |
| | | |
| | | private final ChatConfig chatConfig; |
| | | |
| | | private final ISseService sseService; |
| | | |
| | | /** |
| | | * 知识库对话 |
| | | */ |
| | | @PostMapping("/send") |
| | | public SseEmitter send(@RequestBody @Valid ChatRequest chatRequest) { |
| | | |
| | | openAiStreamClient = chatConfig.getOpenAiStreamClient(); |
| | | SseEmitter sseEmitter = new SseEmitter(0L); |
| | | SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter); |
| | | public SseEmitter send(@RequestBody @Valid ChatRequest chatRequest, HttpServletRequest request) { |
| | | List<Message> messages = chatRequest.getMessages(); |
| | | String content = messages.get(messages.size() - 1).getContent().toString(); |
| | | // 获取知识库信息 |
| | | Message message = messages.get(messages.size() - 1); |
| | | StringBuilder sb = new StringBuilder(message.getContent().toString()); |
| | | List<String> nearestList; |
| | | List<Double> queryVector = embeddingService.getQueryVector(content, chatRequest.getKid()); |
| | | nearestList = vectorStore.nearest(queryVector,chatRequest.getKid()); |
| | | List<Double> queryVector = embeddingService.getQueryVector(message.getContent().toString(), chatRequest.getKid()); |
| | | nearestList = vectorStore.nearest(queryVector, chatRequest.getKid()); |
| | | for (String prompt : nearestList) { |
| | | Message sysMessage = Message.builder().content(prompt).role(Message.Role.USER).build(); |
| | | messages.add(sysMessage); |
| | | sb.append("\n####").append(prompt); |
| | | } |
| | | Message userMessage = Message.builder().content(content + (nearestList.size() > 0 ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "") ).role(Message.Role.USER).build(); |
| | | messages.add(userMessage); |
| | | if (chatRequest.getModel().startsWith("ollama")) { |
| | | return sseService.ollamaChat(chatRequest); |
| | | } |
| | | |
| | | ChatCompletion completion = ChatCompletion |
| | | .builder() |
| | | .messages(messages) |
| | | .model(chatRequest.getModel()) |
| | | .temperature(chatRequest.getTemperature()) |
| | | .topP(chatRequest.getTop_p()) |
| | | .stream(true) |
| | | .build(); |
| | | openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener); |
| | | |
| | | return sseEmitter; |
| | | sb.append( (nearestList.size() > 0 ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "")); |
| | | message.setRole(Message.Role.USER.getName()); |
| | | message.setContent(sb.toString()); |
| | | return sseService.sseChat(chatRequest, request); |
| | | } |
| | | |
| | | /** |