| | |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import okhttp3.*; |
| | | |
| | | import org.ruoyi.chat.config.ChatConfig; |
| | | import org.ruoyi.chat.listener.SSEEventSourceListener; |
| | | import org.ruoyi.chat.service.chat.IChatCostService; |
| | | import org.ruoyi.chat.service.chat.IChatService; |
| | | import org.ruoyi.chat.service.chat.ISseService; |
| | | import org.ruoyi.chat.util.IpUtil; |
| | | import org.ruoyi.chat.util.SSEUtil; |
| | | import org.ruoyi.common.chat.config.LocalCache; |
| | | import org.ruoyi.common.chat.request.ChatRequest; |
| | | import org.ruoyi.common.chat.entity.Tts.TextToSpeech; |
| | | import org.ruoyi.common.chat.entity.chat.ChatCompletion; |
| | |
| | | |
| | | import org.ruoyi.common.redis.utils.RedisUtils; |
| | | |
| | | import org.ruoyi.domain.vo.ChatModelVo; |
| | | import org.ruoyi.service.EmbeddingService; |
| | | import org.ruoyi.service.IChatModelService; |
| | | import org.ruoyi.service.VectorStoreService; |
| | | import org.springframework.core.io.InputStreamResource; |
| | | import org.springframework.core.io.Resource; |
| | |
| | | |
| | | private final IChatService chatService; |
| | | |
| | | private final IChatModelService chatModelService; |
| | | |
| | | private static final String requestIdTemplate = "company-%d"; |
| | | |
| | | private static final ObjectMapper mapper = new ObjectMapper(); |
| | | |
| | | private final ChatConfig chatConfig; |
| | | |
| | | @Override |
| | | public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) { |
| | | SseEmitter sseEmitter = new SseEmitter(0L); |
| | | SseEmitter sseEmitter = new SseEmitter(); |
| | | try { |
| | | // 构建消息列表增加联网、知识库等内容 |
| | | buildChatMessageList(chatRequest); |
| | | if (!StpUtil.isLogin()) { |
| | | // 未登录用户限制对话次数 |
| | | checkUnauthenticatedUserChatLimit(request); |
| | | }else { |
| | | LocalCache.CACHE.put("userId", chatCostService.getUserId()); |
| | | |
| | | chatRequest.setUserId(chatCostService.getUserId()); |
| | | // 保存消息记录 并扣除费用 |
| | | // chatCostService.deductToken(chatRequest); |
| | | } |
| | | // 根据模型名称前缀调用不同的处理逻辑 |
| | | switchModelAndHandle(chatRequest,sseEmitter); |
| | | // 未登录用户限制对话次数 |
| | | checkUnauthenticatedUserChatLimit(request); |
| | | // 保存消息记录 并扣除费用 |
| | | chatCostService.deductToken(chatRequest); |
| | | } catch (Exception e) { |
| | | String message = e.getMessage(); |
| | | SSEUtil.sendErrorEvent(sseEmitter, message); |
| | | return sseEmitter; |
| | | log.error(e.getMessage(),e); |
| | | sseEmitter.completeWithError(e); |
| | | } |
| | | return sseEmitter; |
| | | } |
| | |
| | | * @throws ServiceException 如果当日免费次数已用完 |
| | | */ |
| | | public void checkUnauthenticatedUserChatLimit(HttpServletRequest request) throws ServiceException { |
| | | // 未登录用户限制对话次数 |
| | | if (!StpUtil.isLogin()) { |
| | | |
| | | String clientIp = IpUtil.getClientIp(request); |
| | | // 访客每天默认只能对话5次 |
| | | int timeWindowInSeconds = 5; |
| | |
| | | count++; |
| | | RedisUtils.setCacheObject(redisKey, count); |
| | | } |
| | | } |
| | | |
| | | } |
| | | |
| | | /** |
| | | * 根据模型名称前缀调用不同的处理逻辑 |
| | | */ |
| | | private void switchModelAndHandle(ChatRequest chatRequest,SseEmitter emitter) { |
| | | SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(emitter); |
| | | String model = chatRequest.getModel(); |
| | | // 如果模型名称以ollama开头,则调用ollama中部署的本地模型 |
| | | if (model.startsWith("ollama-")) { |
| | |
| | | } else { |
| | | throw new IllegalArgumentException("Invalid ollama model name: " + chatRequest.getModel()); |
| | | } |
| | | } else if (model.startsWith("gpt-4-gizmo")) { |
| | | chatRequest.setModel("gpt-4-gizmo"); |
| | | } else { |
| | | |
| | | if (model.startsWith("gpt-4-gizmo")) { |
| | | chatRequest.setModel("gpt-4-gizmo"); |
| | | } |
| | | ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel()); |
| | | //openAiStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey()); |
| | | |
| | | ChatCompletion completion = ChatCompletion |
| | | .builder() |
| | | .messages(chatRequest.getMessages()) |
| | | .model(chatRequest.getModel()) |
| | | .temperature(0.2) |
| | | .topP(1.0) |
| | | .stream(true) |
| | | .build(); |
| | | openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener); |
| | | |
| | | } |
| | | } |
| | | |
| | |
| | | * 构建消息列表 |
| | | */ |
| | | private void buildChatMessageList(ChatRequest chatRequest){ |
| | | ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel()); |
| | | // 获取对话消息列表 |
| | | List<Message> messages = chatRequest.getMessages(); |
| | | String sysPrompt = chatRequest.getSysPrompt(); |
| | | String sysPrompt = chatModelVo.getSystemPrompt(); |
| | | if(StringUtils.isEmpty(sysPrompt)){ |
| | | sysPrompt ="你是一个由RuoYI-AI开发的人工智能助手,名字叫熊猫助手。你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。" + |
| | | "当前时间:"+ DateUtils.getDate(); |
| | |
| | | Message sysMessage = Message.builder().content(sysPrompt).role(Message.Role.SYSTEM).build(); |
| | | messages.add(0,sysMessage); |
| | | |
| | | chatRequest.setSysPrompt(sysPrompt); |
| | | // 查询向量库相关信息加入到上下文 |
| | | if(chatRequest.getKid()!=null){ |
| | | if(StringUtils.isNotEmpty(chatRequest.getKid())){ |
| | | List<Message> knMessages = new ArrayList<>(); |
| | | String content = messages.get(messages.size() - 1).getContent().toString(); |
| | | List<String> nearestList; |