办学质量监测教学评价系统
ageerle
2025-04-11 efeb0bd6fb60a6216cb3626df204100546466b23
fix: ollama兼容联网查询 知识库检索
已修改3个文件
159 ■■■■■ 文件已修改
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/OpenAiStreamClient.java 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/request/ChatRequest.java 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java 152 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/OpenAiStreamClient.java
@@ -466,8 +466,8 @@
     * @since 1.1.3
     */
    public ResponseBody textToSpeech(TextToSpeech textToSpeech){
        Call<ResponseBody> responseBody = this.openAiApi.textToSpeech(textToSpeech);
        try {
            Call<ResponseBody> responseBody = this.openAiApi.textToSpeech(textToSpeech);
            return responseBody.execute().body();
        } catch (IOException e) {
            throw new BaseException("文本转语音(同步)失败: "+e.getMessage());
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/request/ChatRequest.java
@@ -27,6 +27,11 @@
    private String prompt;
    /**
     * 系统提示词
     */
    private String sysPrompt;
    /**
     * 是否开启流式对话
     */
    private Boolean stream = Boolean.TRUE;
ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java
@@ -84,11 +84,9 @@
    private final IChatCostService chatCostService;
    private static final String requestIdTemplate = "mycompany-%d";
    private static final String requestIdTemplate = "company-%d";
    private static final ObjectMapper mapper = new ObjectMapper();
    private OpenAiStreamClient openAiModelStreamClient;
    @Override
    public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
@@ -96,65 +94,33 @@
        SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter);
        // 获取对话消息列表
        List<Message> messages = chatRequest.getMessages();
        // 用户对话内容
        String chatString = null;
        try {
            if (StpUtil.isLogin()) {
                // 通过模型名称查询模型信息
                ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
                if(chatModelVo!=null){
                    // 通过模型信息构建请求客户端
                    openAiModelStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
                }else {
                    // 使用默认客户端
                    openAiModelStreamClient  = openAiStreamClient;
                }
            // 查询模型信息
            ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
            OpenAiStreamClient openAiModelStreamClient;
            if(chatModelVo!=null){
                // 构建请求客户端
                openAiModelStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
                // 设置默认提示词
                Message sysMessage = Message.builder().content(chatModelVo.getSystemPrompt()).role(Message.Role.SYSTEM).build();
                messages.add(0,sysMessage);
                // 查询向量库相关信息加入到上下文
                if(chatRequest.getKid()!=null){
                    List<Message> knMessages = new ArrayList<>();
                    String content = messages.get(messages.size() - 1).getContent().toString();
                    List<String> nearestList;
                    List<Double> queryVector = embeddingService.getQueryVector(content, chatRequest.getKid());
                    nearestList = vectorStore.nearest(queryVector, chatRequest.getKid());
                    for (String prompt : nearestList) {
                        Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
                        knMessages.add(userMessage);
                    }
                    Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build();
                    knMessages.add(userMessage);
                    messages.addAll(knMessages);
                }
                // 获取用户对话信息
                Object content = messages.get(messages.size() - 1).getContent();
                if (content instanceof List<?> listContent) {
                    if (CollectionUtil.isNotEmpty(listContent)) {
                        chatString = listContent.get(0).toString();
                    }
                } else if (content instanceof String) {
                    chatString = (String) content;
                }
                // 加载联网信息
                if(chatRequest.getSearch()){
                    Message message = Message.builder().role(Message.Role.ASSISTANT).content("联网信息:"+webSearch(chatString)).build();
                    messages.add(message);
                }
                chatRequest.setSysPrompt(chatModelVo.getSystemPrompt());
            }else {
                // 未登录用户限制对话次数
                String clientIp = IpUtil.getClientIp(request);
                // 使用默认客户端
                openAiModelStreamClient = openAiStreamClient;
            }
            // 构建消息列表增加联网、知识库等内容
            buildChatMessageList(chatRequest);
            // 根据模型名称前缀调用不同的处理逻辑
            switchModelAndHandle(chatRequest);
            // 未登录用户限制对话次数
            if (!StpUtil.isLogin()) {
                String clientIp = IpUtil.getClientIp(request);
                // 访客每天默认只能对话5次
                int timeWindowInSeconds = 5;
                String redisKey = "clientIp:" + clientIp;
                int count = 0;
                if (RedisUtils.getCacheObject(redisKey) == null) {
                    // 缓存有效时间1天
                    RedisUtils.setCacheObject(redisKey, count, Duration.ofSeconds(86400));
@@ -175,6 +141,7 @@
                    .stream(chatRequest.getStream())
                    .build();
            openAiModelStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
            // 保存消息记录 并扣除费用
            chatCostService.deductToken(chatRequest);
        } catch (Exception e) {
@@ -185,6 +152,69 @@
        return sseEmitter;
    }
    /**
     *  根据模型名称前缀调用不同的处理逻辑
     */
    private void switchModelAndHandle(ChatRequest chatRequest) {
        String model = chatRequest.getModel();
        // 如果模型名称以ollama开头,则调用ollama中部署的本地模型
        if (model.startsWith("ollama-")) {
            String[] parts = chatRequest.getModel().split("ollama-", 2); // 限制分割次数为2
            if (parts.length > 1) {
                chatRequest.setModel(parts[1]);
                ollamaChat(chatRequest);
            } else {
                throw new IllegalArgumentException("Invalid ollama model name: " + chatRequest.getModel());
            }
        } else if (model.startsWith("gpt-4-gizmo")) {
            chatRequest.setModel("gpt-4-gizmo");
        }
    }
    /**
     *  构建消息列表
     */
    private void buildChatMessageList(ChatRequest chatRequest){
        // 获取对话消息列表
        List<Message> messages = chatRequest.getMessages();
        // 设置系统默认提示词
        Message sysMessage = Message.builder().content(chatRequest.getSysPrompt()).role(Message.Role.SYSTEM).build();
        messages.add(0,sysMessage);
        // 查询向量库相关信息加入到上下文
        if(chatRequest.getKid()!=null){
            List<Message> knMessages = new ArrayList<>();
            String content = messages.get(messages.size() - 1).getContent().toString();
            List<String> nearestList;
            List<Double> queryVector = embeddingService.getQueryVector(content, chatRequest.getKid());
            nearestList = vectorStore.nearest(queryVector, chatRequest.getKid());
            for (String prompt : nearestList) {
                Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
                knMessages.add(userMessage);
            }
            Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build();
            knMessages.add(userMessage);
            messages.addAll(knMessages);
        }
        // 用户对话内容
        String chatString = null;
        // 获取用户对话信息
        Object content = messages.get(messages.size() - 1).getContent();
        if (content instanceof List<?> listContent) {
            if (CollectionUtil.isNotEmpty(listContent)) {
                chatString = listContent.get(0).toString();
            }
        } else if (content instanceof String) {
            chatString = (String) content;
        }
        // 设置对话信息
        chatRequest.setPrompt(chatString);
        // 加载联网信息
        if(chatRequest.getSearch()){
            Message message = Message.builder().role(Message.Role.ASSISTANT).content("联网信息:"+webSearch(chatString)).build();
            messages.add(message);
        }
    }
    /**
     * 发送SSE错误事件的封装方法
@@ -295,13 +325,13 @@
    @Override
    public SseEmitter ollamaChat(ChatRequest chatRequest) {
        String[] parts = chatRequest.getModel().split("ollama-");
        ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
        final SseEmitter emitter = new SseEmitter();
        String host = chatModelVo.getApiHost();
        List<Message> msgList = chatRequest.getMessages();
        List<OllamaChatMessage> messages = new ArrayList<>();
        List<OllamaChatMessage> messages = new ArrayList<>();
        for (Message message : msgList) {
            OllamaChatMessage ollamaChatMessage = new OllamaChatMessage();
            ollamaChatMessage.setRole(OllamaChatMessageRole.USER);
@@ -310,7 +340,7 @@
        }
        OllamaAPI api = new OllamaAPI(host);
        api.setRequestTimeoutSeconds(100);
        OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(parts[1]);
        OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatRequest.getModel());
        OllamaChatRequestModel requestModel = builder
            .withMessages(messages)
@@ -356,11 +386,11 @@
    @Override
    public String webSearch (String prompt) {
        String zhipuValue = configService.getConfigValue("zhipu", "key");
        if(StringUtils.isEmpty(zhipuValue)){
            throw new IllegalStateException("zhipu config value is empty,请在chat_config中配置zhipu key信息");
        String zpValue = configService.getConfigValue("zhipu", "key");
        if(StringUtils.isEmpty(zpValue)){
            throw new IllegalStateException("请在chat_config中配置智谱key信息");
        }else {
            ClientV4 client = new ClientV4.Builder(zhipuValue)
            ClientV4 client = new ClientV4.Builder(zpValue)
                    .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
                    .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
                    .build();