办学质量监测教学评价系统
ageerle
2025-04-14 188dc1e55e3abbfb00397d67e2b4eed52cead356
ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java
@@ -11,11 +11,13 @@
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.ruoyi.chat.config.ChatConfig;
import org.ruoyi.chat.listener.SSEEventSourceListener;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.chat.service.chat.IChatService;
import org.ruoyi.chat.service.chat.ISseService;
import org.ruoyi.chat.util.IpUtil;
import org.ruoyi.chat.util.SSEUtil;
import org.ruoyi.common.chat.config.LocalCache;
import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
@@ -33,7 +35,9 @@
import org.ruoyi.common.redis.utils.RedisUtils;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.EmbeddingService;
import org.ruoyi.service.IChatModelService;
import org.ruoyi.service.VectorStoreService;
import org.springframework.core.io.InputStreamResource;
import org.springframework.core.io.Resource;
@@ -74,27 +78,35 @@
    private final IChatService chatService;
    private final IChatModelService chatModelService;
    private static final String requestIdTemplate = "company-%d";
    private static final ObjectMapper mapper = new ObjectMapper();
    private final ChatConfig chatConfig;
    @Override
    public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
        SseEmitter sseEmitter = new SseEmitter(0L);
        SseEmitter sseEmitter = new SseEmitter();
        try {
            // 构建消息列表增加联网、知识库等内容
            buildChatMessageList(chatRequest);
            if (!StpUtil.isLogin()) {
                // 未登录用户限制对话次数
                checkUnauthenticatedUserChatLimit(request);
            }else {
                LocalCache.CACHE.put("userId", chatCostService.getUserId());
                chatRequest.setUserId(chatCostService.getUserId());
                // 保存消息记录 并扣除费用
                // chatCostService.deductToken(chatRequest);
            }
            // 根据模型名称前缀调用不同的处理逻辑
            switchModelAndHandle(chatRequest,sseEmitter);
            // 未登录用户限制对话次数
            checkUnauthenticatedUserChatLimit(request);
            // 保存消息记录 并扣除费用
            chatCostService.deductToken(chatRequest);
        } catch (Exception e) {
            String message = e.getMessage();
            SSEUtil.sendErrorEvent(sseEmitter, message);
            return sseEmitter;
            log.error(e.getMessage(),e);
            sseEmitter.completeWithError(e);
        }
        return sseEmitter;
    }
@@ -106,8 +118,7 @@
     * @throws ServiceException 如果当日免费次数已用完
     */
    public void checkUnauthenticatedUserChatLimit(HttpServletRequest request) throws ServiceException {
        // 未登录用户限制对话次数
        if (!StpUtil.isLogin()) {
            String clientIp = IpUtil.getClientIp(request);
            // 访客每天默认只能对话5次
            int timeWindowInSeconds = 5;
@@ -125,13 +136,14 @@
                count++;
                RedisUtils.setCacheObject(redisKey, count);
            }
        }
    }
    /**
     *  根据模型名称前缀调用不同的处理逻辑
     */
    private void switchModelAndHandle(ChatRequest chatRequest,SseEmitter emitter) {
        SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(emitter);
        String model = chatRequest.getModel();
        // 如果模型名称以ollama开头,则调用ollama中部署的本地模型
        if (model.startsWith("ollama-")) {
@@ -142,8 +154,24 @@
            } else {
                throw new IllegalArgumentException("Invalid ollama model name: " + chatRequest.getModel());
            }
        } else if (model.startsWith("gpt-4-gizmo")) {
            chatRequest.setModel("gpt-4-gizmo");
        } else {
            if (model.startsWith("gpt-4-gizmo")) {
                chatRequest.setModel("gpt-4-gizmo");
            }
            ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
            //openAiStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
            ChatCompletion completion = ChatCompletion
                    .builder()
                    .messages(chatRequest.getMessages())
                    .model(chatRequest.getModel())
                    .temperature(0.2)
                    .topP(1.0)
                    .stream(true)
                    .build();
            openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
        }
    }
@@ -151,9 +179,10 @@
     *  构建消息列表
     */
    private void buildChatMessageList(ChatRequest chatRequest){
        ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
        // 获取对话消息列表
        List<Message> messages = chatRequest.getMessages();
        String sysPrompt = chatRequest.getSysPrompt();
        String sysPrompt = chatModelVo.getSystemPrompt();
        if(StringUtils.isEmpty(sysPrompt)){
            sysPrompt ="你是一个由RuoYI-AI开发的人工智能助手,名字叫熊猫助手。你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。" +
                    "当前时间:"+ DateUtils.getDate();
@@ -162,8 +191,9 @@
        Message sysMessage = Message.builder().content(sysPrompt).role(Message.Role.SYSTEM).build();
        messages.add(0,sysMessage);
        chatRequest.setSysPrompt(sysPrompt);
        // 查询向量库相关信息加入到上下文
        if(chatRequest.getKid()!=null){
        if(StringUtils.isNotEmpty(chatRequest.getKid())){
            List<Message> knMessages = new ArrayList<>();
            String content = messages.get(messages.size() - 1).getContent().toString();
            List<String> nearestList;