package com.xmzs.system.service.impl;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.json.JSONUtil;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.*;
import com.azure.ai.openai.models.ImageGenerationData;
import com.azure.ai.openai.models.ImageGenerationOptions;
import com.azure.ai.openai.models.ImageGenerations;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.models.ResponseError;
import com.azure.core.util.IterableStream;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.xmzs.common.chat.config.LocalCache;
import com.xmzs.common.chat.constant.OpenAIConst;
import com.xmzs.common.chat.domain.request.ChatRequest;
import com.xmzs.common.chat.domain.request.Dall3Request;
import com.xmzs.common.chat.entity.chat.*;
import com.xmzs.common.chat.entity.images.Image;
import com.xmzs.common.chat.entity.images.ImageResponse;
import com.xmzs.common.chat.entity.images.Item;
import com.xmzs.common.chat.entity.images.ResponseFormat;
import com.xmzs.common.chat.openai.OpenAiStreamClient;
import com.xmzs.common.chat.utils.TikTokensUtil;
import com.xmzs.common.core.domain.model.LoginUser;
import com.xmzs.common.core.exception.ServiceException;
import com.xmzs.common.core.exception.base.BaseException;
import com.xmzs.common.core.utils.StringUtils;
import com.xmzs.common.satoken.utils.LoginHelper;
import com.xmzs.common.translation.annotation.Translation;
import com.xmzs.system.domain.SysUser;
import com.xmzs.system.domain.bo.ChatMessageBo;
import com.xmzs.system.listener.SSEEventSourceListener;
import com.xmzs.system.mapper.SysUserMapper;
import com.xmzs.system.service.ChatService;
import com.xmzs.system.service.IChatMessageService;
import com.xmzs.system.service.SseService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
 * SSE chat service: streams model responses to the client through an {@link SseEmitter}
 * and charges the logged-in user for chat, DALL-E 3 and Midjourney requests.
 *
 * @author https:www.unfbx.com
 * @date 2023-04-08
 */
@Service
@Slf4j
@RequiredArgsConstructor
public class SseServiceImpl implements SseService {

    /** Shown to free-grade users who request a model other than gpt-3.5-turbo. */
    private static final String FREE_USER_MODEL_TIP =
        "免费用户暂时不支持此模型,请切换gpt-3.5-turbo模型或者点击《进入市场选购您的商品》充值后使用!";

    private static final String DONE_SIGNAL = "[DONE]";

    /** SSE data field name; per the SSE spec one space after the colon is optional. */
    private static final String SSE_DATA_FIELD = "data:";

    /** Thread-safe once configured; reused instead of being rebuilt per request. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /** Reused so connections can be pooled instead of building a client per request. */
    private static final HttpClient HTTP_CLIENT = HttpClient.newHttpClient();

    /**
     * Shared pool for background stream relaying. Idle threads are reclaimed automatically,
     * so unlike a per-call executor that is never shut down, no threads accumulate.
     */
    private static final ExecutorService STREAM_EXECUTOR = Executors.newCachedThreadPool(runnable -> {
        Thread thread = new Thread(runnable, "sse-stream");
        thread.setDaemon(true);
        return thread;
    });

    private final OpenAiStreamClient openAiStreamClient;

    private final ChatService chatService;

    private final SysUserMapper sysUserMapper;

    private final IChatMessageService chatMessageService;

    @Value("${transit.apiKey}")
    private String API_KEY;

    @Value("${transit.apiHost}")
    private String API_HOST;

    /**
     * Streams a chat completion to the client and charges the current user.
     *
     * <p>For {@code gpt-4-vision-preview} the picture/text content list is sent and a flat
     * {@code OpenAIConst.GPT4_COST} is charged. For other models the message list is sent and
     * {@link ChatService#deductToken} is called with the prompt's token count. If the user's
     * grade does not permit the model, an {@code error} event is sent, the emitter is
     * completed, and nothing is streamed or charged.
     *
     * @param chatRequest model, messages (or picture content) and sampling parameters
     * @return the emitter the response is streamed through
     * @throws ServiceException if the request carries no messages / no picture content
     */
    @Override
    @Transactional
    public SseEmitter sseChat(ChatRequest chatRequest) {
        Long userId = getUserId();
        // NOTE(review): fixed key; if CACHE is process-wide, concurrent requests from different
        // users overwrite each other here — verify how SSEEventSourceListener reads it.
        LocalCache.CACHE.put("userId", userId);
        SseEmitter sseEmitter = new SseEmitter(0L);
        if (rejectIfModelNotAllowed(sseEmitter, chatRequest.getModel())) {
            // The emitter was already completed with an error event: do not stream or charge.
            return sseEmitter;
        }
        SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter);

        if (ChatCompletion.Model.GPT_4_VISION_PREVIEW.getName().equals(chatRequest.getModel())) {
            // Image+text recognition: the whole content list goes out as a single user message.
            List<Content> contentList = chatRequest.getContent();
            if (CollectionUtil.isEmpty(contentList)) {
                throw new ServiceException("图文识别内容不能为空!");
            }
            MessagePicture message = MessagePicture.builder()
                .role(Message.Role.USER.getName())
                .content(contentList)
                .build();
            ChatCompletionWithPicture chatCompletion = ChatCompletionWithPicture.builder()
                .messages(Collections.singletonList(message))
                .model(chatRequest.getModel())
                .temperature(chatRequest.getTemperature())
                .topP(chatRequest.getTop_p())
                .stream(true)
                .build();
            openAiStreamClient.streamChatCompletion(chatCompletion, openAIEventSourceListener);

            // Vision chats are charged a flat fee rather than per token.
            chatService.deductUserBalance(userId, OpenAIConst.GPT4_COST);

            String text = contentList.get(contentList.size() - 1).getText();
            ChatMessageBo chatMessageBo = new ChatMessageBo();
            chatMessageBo.setUserId(userId);
            chatMessageBo.setModelName(chatRequest.getModel());
            chatMessageBo.setContent(text);
            chatMessageBo.setDeductCost(OpenAIConst.GPT4_COST);
            chatMessageBo.setTotalTokens(0);
            chatMessageService.insertByBo(chatMessageBo);
        } else {
            List<Message> msgList = chatRequest.getMessages();
            if (CollectionUtil.isEmpty(msgList)) {
                throw new ServiceException("对话消息不能为空!");
            }
            ChatCompletion completion = ChatCompletion.builder()
                .messages(msgList)
                .model(chatRequest.getModel())
                .temperature(chatRequest.getTemperature())
                .topP(chatRequest.getTop_p())
                .stream(true)
                .build();
            openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);

            // Charge by the token count of the prompt history; the last message is what gets logged.
            Message message = msgList.get(msgList.size() - 1);
            int tokens = TikTokensUtil.tokens(chatRequest.getModel(), msgList);
            ChatMessageBo chatMessageBo = new ChatMessageBo();
            chatMessageBo.setUserId(userId);
            chatMessageBo.setModelName(chatRequest.getModel());
            chatMessageBo.setContent(message.getContent());
            chatMessageBo.setTotalTokens(tokens);
            chatService.deductToken(chatMessageBo);
        }
        return sseEmitter;
    }

    /**
     * Generates one DALL-E 3 image and charges the current user.
     *
     * <p>Sizes {@code 1792x1024} and {@code 1024x1792} are charged {@code DALL3_HD_COST}; all
     * other sizes are charged {@code DALL3_COST}. The amount actually charged is the one
     * written to the chat-message record.
     *
     * @param request prompt, quality, size and style of the image
     * @return the generated image items (URL response format)
     * @throws ServiceException if the current user is on the free grade
     */
    public List<Item> dall3(Dall3Request request) {
        checkUserGrade(null, "");
        Image image = Image.builder()
            .responseFormat(ResponseFormat.URL.getName())
            .model(Image.Model.DALL_E_3.getName())
            .prompt(request.getPrompt())
            .n(1)
            .quality(request.getQuality())
            .size(request.getSize())
            .style(request.getStyle())
            .build();
        ImageResponse imageResponse = openAiStreamClient.genImages(image);

        Long userId = getUserId();
        ChatMessageBo chatMessageBo = new ChatMessageBo();
        chatMessageBo.setUserId(userId);
        chatMessageBo.setModelName(Image.Model.DALL_E_3.getName());
        chatMessageBo.setContent(request.getPrompt());
        chatMessageBo.setTotalTokens(0);
        boolean wideSize = Objects.equals(request.getSize(), "1792x1024")
            || Objects.equals(request.getSize(), "1024x1792");
        if (wideSize) {
            chatService.deductUserBalance(userId, OpenAIConst.DALL3_HD_COST);
            chatMessageBo.setDeductCost(OpenAIConst.DALL3_HD_COST);
        } else {
            chatService.deductUserBalance(userId, OpenAIConst.DALL3_COST);
            chatMessageBo.setDeductCost(OpenAIConst.DALL3_COST);
        }
        chatMessageService.insertByBo(chatMessageBo);
        return imageResponse.getData();
    }

    /**
     * Charges the current user for one Midjourney task and records the charge.
     *
     * @throws ServiceException if the current user is on the free grade
     */
    @Override
    public void mjTask() {
        checkUserGrade(null, "");
        Long userId = getUserId();
        chatService.deductUserBalance(userId, OpenAIConst.MJ_COST);

        ChatMessageBo chatMessageBo = new ChatMessageBo();
        chatMessageBo.setUserId(userId);
        chatMessageBo.setModelName("mj");
        chatMessageBo.setContent("mj绘图");
        chatMessageBo.setDeductCost(OpenAIConst.MJ_COST);
        chatMessageBo.setTotalTokens(0);
        chatMessageService.insertByBo(chatMessageBo);
    }

    /**
     * Relays a streaming chat completion from the configured transit (proxy) endpoint.
     *
     * <p>The upstream stream is read on a background thread and every SSE {@code data} payload,
     * including the final {@code [DONE]}, is forwarded to the client as plain text. The user is
     * charged {@code GPT4_COST} up front, on the request thread, where the login context is
     * available.
     *
     * @param chatRequest model, messages and sampling parameters
     * @return the emitter the relayed stream is written to
     * @throws ServiceException if the request carries no messages
     */
    @Override
    public SseEmitter transitChat(ChatRequest chatRequest) {
        List<Message> msgList = chatRequest.getMessages();
        if (CollectionUtil.isEmpty(msgList)) {
            throw new ServiceException("对话消息不能为空!");
        }
        Message message = msgList.get(msgList.size() - 1);
        SseEmitter emitter = new SseEmitter(0L);
        if (rejectIfModelNotAllowed(emitter, chatRequest.getModel())) {
            // The emitter was already completed with an error event: do not call upstream or charge.
            return emitter;
        }
        ChatCompletion completion = ChatCompletion.builder()
            .messages(msgList)
            .model(chatRequest.getModel())
            .temperature(chatRequest.getTemperature())
            .topP(chatRequest.getTop_p())
            .stream(true)
            .build();

        STREAM_EXECUTOR.execute(() -> relayTransitStream(completion, emitter));

        Long userId = getUserId();
        chatService.deductUserBalance(userId, OpenAIConst.GPT4_COST);
        ChatMessageBo chatMessageBo = new ChatMessageBo();
        chatMessageBo.setUserId(userId);
        chatMessageBo.setModelName(chatRequest.getModel());
        chatMessageBo.setContent(message.getContent());
        chatMessageBo.setDeductCost(OpenAIConst.GPT4_COST);
        chatMessageBo.setTotalTokens(0);
        chatMessageService.insertByBo(chatMessageBo);
        return emitter;
    }

    /**
     * POSTs {@code completion} to the transit endpoint and forwards each SSE data payload to
     * {@code emitter}. Always completes the emitter: on {@code [DONE]}, at end of stream, on a
     * non-200 status, or on failure (which is logged, since nothing can catch it on this thread).
     */
    private void relayTransitStream(ChatCompletion completion, SseEmitter emitter) {
        try {
            String requestBody = OBJECT_MAPPER.writeValueAsString(completion);
            HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create(transitChatUrl()))
                .header("Authorization", "Bearer " + API_KEY)
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(requestBody))
                .build();
            HttpResponse<InputStream> response =
                HTTP_CLIENT.send(request, HttpResponse.BodyHandlers.ofInputStream());
            try (InputStream body = response.body()) {
                if (response.statusCode() != 200) {
                    String error = new String(body.readAllBytes(), StandardCharsets.UTF_8);
                    log.error("调用中转接口失败: status={}, body={}", response.statusCode(), error);
                    return;
                }
                BufferedReader reader = new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8));
                String line;
                while ((line = reader.readLine()) != null) {
                    if (!line.startsWith(SSE_DATA_FIELD)) {
                        continue;
                    }
                    // Strip only the field prefix (and its optional single space), never payload text.
                    String data = line.substring(SSE_DATA_FIELD.length());
                    if (data.startsWith(" ")) {
                        data = data.substring(1);
                    }
                    emitter.send(data, MediaType.TEXT_PLAIN);
                    if (DONE_SIGNAL.equals(data)) {
                        break;
                    }
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("调用中转接口被中断", e);
        } catch (Exception e) {
            log.error("调用中转接口失败:{}", e.getMessage(), e);
        } finally {
            emitter.complete();
        }
    }

    /** Builds the chat-completions URL, tolerating an {@code API_HOST} without a trailing slash. */
    private String transitChatUrl() {
        String host = API_HOST.endsWith("/") ? API_HOST : API_HOST + "/";
        return host + "v1/chat/completions";
    }

    /**
     * Manual smoke test for Azure OpenAI image generation. The key, endpoint and deployment
     * below are placeholders and must be filled in before running.
     */
    public static void main(String[] args) {
        String azureOpenaiKey = "-";
        String endpoint = "-";
        String deploymentOrModelName = "-";
        OpenAIClient client = new OpenAIClientBuilder()
            .endpoint(endpoint)
            .credential(new AzureKeyCredential(azureOpenaiKey))
            .buildClient();
        ImageGenerationOptions imageGenerationOptions = new ImageGenerationOptions(
            "A drawing of the Seattle skyline in the style of Van Gogh");
        ImageGenerations images = client.getImageGenerations(deploymentOrModelName, imageGenerationOptions);
        for (ImageGenerationData imageGenerationData : images.getData()) {
            System.out.printf(
                "Image location URL that provides temporary access to download the generated image is %s.%n",
                imageGenerationData.getUrl());
        }
    }

    /**
     * Streams a chat completion from Azure OpenAI to the client.
     *
     * <p>Each non-empty {@link ChatCompletions} chunk is sent as-is. The emitter is completed
     * normally at the end of the stream, or with the error if anything fails.
     * TODO: the key, endpoint and deployment are placeholders; load them from configuration.
     *
     * @param chatRequest request whose message history is forwarded
     * @return the emitter the response is streamed through
     */
    public SseEmitter azureChat(ChatRequest chatRequest) {
        String azureOpenaiKey = "-";
        String endpoint = "-";
        String deploymentOrModelId = "-";
        OpenAIClient client = new OpenAIClientBuilder()
            .endpoint(endpoint)
            .credential(new AzureKeyCredential(azureOpenaiKey))
            .buildClient();
        final SseEmitter emitter = new SseEmitter();
        // Shared pool: a per-call executor that is never shut down leaks one thread per request.
        STREAM_EXECUTOR.execute(() -> {
            try {
                List<ChatRequestMessage> messages = new ArrayList<>();
                for (Message chatMessage : chatRequest.getMessages()) {
                    messages.add(toAzureMessage(chatMessage));
                }
                IterableStream<ChatCompletions> chatCompletionsStream =
                    client.getChatCompletionsStream(deploymentOrModelId, new ChatCompletionsOptions(messages));
                for (ChatCompletions chatCompletion : chatCompletionsStream) {
                    if (CollectionUtil.isEmpty(chatCompletion.getChoices())) {
                        continue;
                    }
                    log.info("json ======{}", JSONUtil.toJsonStr(chatCompletion));
                    emitter.send(chatCompletion);
                }
                emitter.complete();
            } catch (Exception e) {
                emitter.completeWithError(e);
            }
        });
        return emitter;
    }

    /**
     * Maps a project chat message to the Azure request message of the same role.
     * System and assistant messages keep their role; every other role is sent as a user message.
     */
    private static ChatRequestMessage toAzureMessage(Message message) {
        String role = message.getRole();
        if (Message.Role.SYSTEM.getName().equals(role)) {
            return new ChatRequestSystemMessage(message.getContent());
        }
        if (Message.Role.ASSISTANT.getName().equals(role)) {
            return new ChatRequestAssistantMessage(message.getContent());
        }
        return new ChatRequestUserMessage(message.getContent());
    }

    /**
     * Checks whether the current user's grade permits {@code model}.
     *
     * <p>Free-grade users (grade {@code "0"}) may only use gpt-3.5-turbo. With an empty
     * {@code model} (non-chat features) or a {@code null} emitter, a free user is rejected with a
     * {@link ServiceException}. Otherwise, an {@code error} SSE event is sent and the emitter is
     * completed.
     *
     * @param emitter emitter to notify, or {@code null} when there is no stream
     * @param model   requested model name; empty for non-chat features
     * @throws ServiceException if the user is unknown, or is free and cannot be notified via SSE
     */
    public void checkUserGrade(SseEmitter emitter, String model) {
        rejectIfModelNotAllowed(emitter, model);
    }

    /**
     * Same check as {@link #checkUserGrade}, but tells the caller whether the request was
     * rejected through the emitter, so the caller can stop before streaming or charging.
     *
     * @return {@code true} if an error event was sent and the emitter completed
     */
    private boolean rejectIfModelNotAllowed(SseEmitter emitter, String model) {
        SysUser sysUser = sysUserMapper.selectById(getUserId());
        if (sysUser == null) {
            throw new ServiceException("用户不存在!");
        }
        // TODO: replace the "0" (free) grade literal with an enum
        if (!"0".equals(sysUser.getUserGrade())) {
            return false;
        }
        if (StringUtils.isEmpty(model)) {
            throw new ServiceException(FREE_USER_MODEL_TIP, 500);
        }
        if (ChatCompletion.Model.GPT_3_5_TURBO.getName().equals(model)) {
            return false;
        }
        if (emitter == null) {
            throw new ServiceException(FREE_USER_MODEL_TIP, 500);
        }
        // The client listens for the event named "error".
        SseEmitter.SseEventBuilder event = SseEmitter.event()
            .name("error")
            .data(FREE_USER_MODEL_TIP);
        try {
            emitter.send(event);
            emitter.complete();
        } catch (IOException e) {
            // Sending failed (presumably the client went away); end the emitter with the cause.
            emitter.completeWithError(e);
        }
        return true;
    }

    /**
     * Returns the id of the currently logged-in user.
     *
     * @return the user id
     * @throws BaseException if no user is logged in
     */
    public Long getUserId() {
        LoginUser loginUser = LoginHelper.getLoginUser();
        if (loginUser == null) {
            throw new BaseException("用户未登录!");
        }
        return loginUser.getUserId();
    }
}