办学质量监测教学评价系统
ageer
2025-03-12 d8fda1559351c61678738285041da6811b463e53
ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SseServiceImpl.java
@@ -1,8 +1,10 @@
package org.ruoyi.system.service.impl;
import cn.dev33.satoken.stp.StpUtil;
import cn.hutool.core.collection.CollectionUtil;
import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.zhipu.oapi.ClientV4;
import com.zhipu.oapi.service.v4.tools.*;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
@@ -17,10 +19,7 @@
import org.ruoyi.common.chat.domain.request.ChatRequest;
import org.ruoyi.common.chat.domain.request.Dall3Request;
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
import org.ruoyi.common.chat.entity.chat.Content;
import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.entity.chat.*;
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
import org.ruoyi.common.chat.entity.images.Image;
import org.ruoyi.common.chat.entity.images.ImageResponse;
@@ -33,17 +32,15 @@
import org.ruoyi.common.chat.plugin.CmdReq;
import org.ruoyi.common.chat.plugin.SqlPlugin;
import org.ruoyi.common.chat.plugin.SqlReq;
import org.ruoyi.common.chat.sse.ConsoleEventSourceListener;
import org.ruoyi.common.chat.utils.TikTokensUtil;
import org.ruoyi.common.core.domain.model.LoginUser;
import org.ruoyi.common.core.exception.base.BaseException;
import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.common.satoken.utils.LoginHelper;
import org.ruoyi.system.domain.SysModel;
import org.ruoyi.system.domain.bo.ChatMessageBo;
import org.ruoyi.system.domain.bo.SysModelBo;
import org.ruoyi.system.domain.request.translation.TranslationRequest;
import org.ruoyi.system.domain.vo.SysModelVo;
import org.ruoyi.system.listener.SSEEventSourceListener;
import org.ruoyi.system.service.*;
import org.springframework.core.io.InputStreamResource;
@@ -65,6 +62,9 @@
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
@Service
@@ -76,17 +76,20 @@
    private final ChatConfig chatConfig;
    private final IChatCostService chatService;
    private final IChatMessageService chatMessageService;
    private final ISysModelService sysModelService;
    private final ISysUserService userService;
    private final ConfigService configService;
    static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
    private static final String requestIdTemplate = "mycompany-%d";
    private static final ObjectMapper mapper = new ObjectMapper();
    @Override
    public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
@@ -96,11 +99,10 @@
        // 获取对话消息列表
        List<Message> messages = chatRequest.getMessages();
        try {
            String chatString = null;
            if (StpUtil.isLogin()) {
                LocalCache.CACHE.put("userId", getUserId());
                Object content = messages.get(messages.size() - 1).getContent();
                String chatString = "";
                if (content instanceof List<?> listContent) {
                    if (!listContent.isEmpty() && listContent.get(0) instanceof Content) {
                        chatString = ((Content) listContent.get(0)).getText();
@@ -123,39 +125,89 @@
                        throw new BaseException("文本不合规,请修改!");
                    }
                }
                //根据模型名称查询模型信息
                SysModelBo sysModelBo = new SysModelBo();
                String model = chatRequest.getModel();
                // 如果是gpts系列模型
                if (chatRequest.getModel().startsWith("gpt-4-gizmo")) {
                    sysModelBo.setModelName("gpt-4-gizmo");
                } else {
                    sysModelBo.setModelName(chatRequest.getModel());
                    model = "gpt-4-gizmo";
                }
                List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
                if (CollectionUtil.isEmpty(sysModelList)) {
                SysModel sysModel = sysModelService.selectModelByName(model);
                if (sysModel != null) {
                    // 如果模型不存在默认使用token扣费方式
                    processByToken(chatRequest.getModel(), chatString, chatMessageBo);
                } else {
                    openAiStreamClient = chatConfig.createOpenAiStreamClient(sysModelList.get(0).getApiHost(), sysModelList.get(0).getApiKey());
                    openAiStreamClient = chatConfig.createOpenAiStreamClient(sysModel.getApiHost(), sysModel.getApiKey());
                    // 模型设置默认提示词
                    SysModelVo firstModel = sysModelList.get(0);
                    if (StringUtils.isNotEmpty(firstModel.getSystemPrompt())) {
                        Message sysMessage = Message.builder().content(firstModel.getSystemPrompt()).role(Message.Role.SYSTEM).build();
                    if (StringUtils.isNotEmpty(sysModel.getSystemPrompt())) {
                        Message sysMessage = Message.builder().content(sysModel.getSystemPrompt()).role(Message.Role.SYSTEM).build();
                        messages.add(sysMessage);
                    }
                    // 计费类型: 1 token扣费 2 次数扣费
                    if ("2".equals(firstModel.getModelType())) {
                        processByModelPrice(firstModel, chatMessageBo);
                    if ("2".equals(sysModel.getModelType())) {
                        processByModelPrice(sysModel, chatMessageBo);
                    } else {
                       processByToken(chatRequest.getModel(), chatString, chatMessageBo);
                        processByToken(chatRequest.getModel(), chatString, chatMessageBo);
                    }
                }
            }
            if("openCmd".equals(chatRequest.getModel())) {
            String configValue = configService.getConfigValue("zhipu", "key");
            // 添加联网信息
            if(StringUtils.isNotEmpty(configValue)){
                ClientV4 client = new ClientV4.Builder(configValue)
                        .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
                        .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
                        .build();
                SearchChatMessage jsonNodes = new SearchChatMessage();
                jsonNodes.setRole(Message.Role.USER.getName());
                jsonNodes.setContent(chatString);
                String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
                WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
                        .model("web-search-pro")
                        .stream(Boolean.TRUE)
                        .messages(Collections.singletonList(jsonNodes))
                        .requestId(requestId)
                        .build();
                WebSearchApiResponse webSearchApiResponse = client.webSearchProStreamingInvoke(chatCompletionRequest);
                List<ChoiceDelta> choices = new ArrayList<>();
                if (webSearchApiResponse.isSuccess()) {
                    AtomicBoolean isFirst = new AtomicBoolean(true);
                    AtomicReference<WebSearchPro> lastAccumulator = new AtomicReference<>();
                    webSearchApiResponse.getFlowable().map(result -> result)
                            .doOnNext(accumulator -> {
                                {
                                    if (isFirst.getAndSet(false)) {
                                        log.info("Response: ");
                                    }
                                    ChoiceDelta delta = accumulator.getChoices().get(0).getDelta();
                                    if (delta != null && delta.getToolCalls() != null) {
                                        log.info("tool_calls: {}", mapper.writeValueAsString(delta.getToolCalls()));
                                    }
                                    choices.add(delta);
                                }
                            })
                            .doOnComplete(() -> System.out.println("Stream completed."))
                            .doOnError(throwable -> System.err.println("Error: " + throwable))
                            .blockingSubscribe();
                    WebSearchPro chatMessageAccumulator = lastAccumulator.get();
                    webSearchApiResponse.setFlowable(null);// 打印前置空
                    webSearchApiResponse.setData(chatMessageAccumulator);
                }
                Message message = Message.builder().role(Message.Role.ASSISTANT).content(choices.get(1).getToolCalls().toString()).build();
                messages.add(message);
            }
            if ("openCmd".equals(chatRequest.getModel())) {
                sseEmitter.send(cmdPlugin(messages));
                sseEmitter.complete();
            }else if ("sqlPlugin".equals(chatRequest.getModel())){
            } else if ("sqlPlugin".equals(chatRequest.getModel())) {
                sseEmitter.send(sqlPlugin(messages));
                sseEmitter.complete();
            } else {
@@ -229,7 +281,7 @@
     * @param model         模型信息
     * @param chatMessageBo 对话信息
     */
    private void processByModelPrice(SysModelVo model, ChatMessageBo chatMessageBo) {
    private void processByModelPrice(SysModel model, ChatMessageBo chatMessageBo) {
        double cost = model.getModelPrice();
        chatService.deductUserBalance(getUserId(), cost);
        chatMessageBo.setDeductCost(cost);
@@ -316,16 +368,14 @@
            .style(request.getStyle())
            .build();
        ImageResponse imageResponse = openAiStreamClient.genImages(image);
        SysModelBo sysModelBo = new SysModelBo();
        sysModelBo.setModelName(request.getModel());
        List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
        SysModel sysModel = sysModelService.selectModelByName(request.getModel());
        //chatService.deductUserBalance(getUserId(),sysModelList.get(0).getModelPrice());
        // 保存消息记录
        ChatMessageBo chatMessageBo = new ChatMessageBo();
        chatMessageBo.setUserId(getUserId());
        chatMessageBo.setModelName(Image.Model.DALL_E_3.getName());
        chatMessageBo.setContent(request.getPrompt());
        chatMessageBo.setDeductCost(sysModelList.get(0).getModelPrice());
        chatMessageBo.setDeductCost(sysModel.getModelPrice());
        chatMessageBo.setTotalTokens(0);
        chatMessageService.insertByBo(chatMessageBo);
        return imageResponse.getData();
@@ -342,16 +392,14 @@
            .n(1)
            .build();
        ImageResponse imageResponse = openAiStreamClient.genImages(image);
        SysModelBo sysModelBo = new SysModelBo();
        sysModelBo.setModelName("dall3");
        List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
        SysModel dall3 = sysModelService.selectModelByName("dall3");
        chatService.deductUserBalance(Long.valueOf(userId), 0.3);
        // 保存消息记录
        ChatMessageBo chatMessageBo = new ChatMessageBo();
        chatMessageBo.setUserId(getUserId());
        chatMessageBo.setModelName(Image.Model.DALL_E_3.getName());
        chatMessageBo.setContent(prompt);
        chatMessageBo.setDeductCost(sysModelList.get(0).getModelPrice());
        chatMessageBo.setDeductCost(dall3.getModelPrice());
        chatMessageBo.setTotalTokens(0);
        chatMessageService.insertByBo(chatMessageBo);
        return imageResponse.getData();
@@ -527,12 +575,9 @@
        chatMessageBo.setDeductCost(0.01);
        chatMessageBo.setTotalTokens(0);
        chatMessageService.insertByBo(chatMessageBo);
        openAiStreamClient = chatConfig.getOpenAiStreamClient();
        List<Message> messageList = new ArrayList<>();
        Message sysMessage = Message.builder().role(Message.Role.SYSTEM).content("你是一名翻译老师\n" +
        Message sysMessage = Message.builder().role(Message.Role.SYSTEM).content("你是一位精通各国语言的翻译大师\n" +
            "\n" +
            "请将用户输入词语翻译成{" + translationRequest.getTargetLanguage() + "}\n" +
            "\n" +
@@ -563,25 +608,21 @@
    @Override
    public SseEmitter ollamaChat(ChatRequest chatRequest) {
        String[] parts = chatRequest.getModel().split("ollama-");
        SysModel sysModel = sysModelService.selectModelByName(parts[1]);
        final SseEmitter emitter = new SseEmitter();
        String host = "http://localhost:11434/";
        String host = sysModel.getApiHost();
        List<Message> msgList = chatRequest.getMessages();
        Message message = msgList.get(msgList.size() - 1);
        OllamaAPI ollamaAPI = new OllamaAPI(host);
        ollamaAPI.setRequestTimeoutSeconds(100);
        OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("qwen2.5:7b");
        OllamaAPI api = new OllamaAPI(host);
        api.setRequestTimeoutSeconds(100);
        OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(sysModel.getModelName());
        OllamaChatRequestModel requestModel = builder
            .withMessage(OllamaChatMessageRole.USER,
                message.getContent().toString())
            .build();
        // 异步执行 Ollama API 调用
        // 异步执行 OllAma API 调用
        CompletableFuture.runAsync(() -> {
            try {
                StringBuilder response = new StringBuilder();
@@ -595,14 +636,12 @@
                        sendErrorEvent(emitter, e.getMessage());
                    }
                };
                ollamaAPI.chat(requestModel, streamHandler);
                api.chat(requestModel, streamHandler);
                emitter.complete();
            } catch (Exception e) {
                sendErrorEvent(emitter, e.getMessage());
            }
        });
        return emitter;
    }
@@ -620,6 +659,4 @@
        ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
        return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString();
    }
}