package org.ruoyi.chat.service.chat.impl; import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.ruoyi.chat.enums.BillingType; import org.ruoyi.chat.enums.UserGradeType; import org.ruoyi.chat.service.chat.IChatCostService; import org.ruoyi.common.chat.config.LocalCache; import org.ruoyi.common.chat.request.ChatRequest; import org.ruoyi.common.chat.utils.TikTokensUtil; import org.ruoyi.common.core.domain.model.LoginUser; import org.ruoyi.common.core.exception.ServiceException; import org.ruoyi.common.core.exception.base.BaseException; import org.ruoyi.common.satoken.utils.LoginHelper; import org.ruoyi.domain.ChatToken; import org.ruoyi.domain.bo.ChatMessageBo; import org.ruoyi.domain.vo.ChatModelVo; import org.ruoyi.service.IChatMessageService; import org.ruoyi.service.IChatModelService; import org.ruoyi.service.IChatTokenService; import org.ruoyi.system.domain.SysUser; import org.ruoyi.system.mapper.SysUserMapper; import org.springframework.stereotype.Service; /** * 计费管理Service实现 * * @author ageerle * @date 2025-04-08 */ @Slf4j @Service @RequiredArgsConstructor public class ChatCostServiceImpl implements IChatCostService { private final SysUserMapper sysUserMapper; private final IChatMessageService chatMessageService; private final IChatTokenService chatTokenService; private final IChatModelService chatModelService; /** * 扣除用户余额 */ public void deductToken(ChatRequest chatRequest) { int tokens = TikTokensUtil.tokens(chatRequest.getModel(), chatRequest.getPrompt()); String modelName = chatRequest.getModel(); ChatMessageBo chatMessageBo = new ChatMessageBo(); Object userId = LocalCache.CACHE.get("userId"); if(userId!=null){ chatMessageBo.setUserId((Long) userId); }else { chatMessageBo.setUserId(getUserId()); } // 计算总token数 ChatToken chatToken = chatTokenService.queryByUserId(getUserId(), modelName); if (chatToken == null) { chatToken = new ChatToken(); chatToken.setToken(0); } int totalTokens = chatToken.getToken() + tokens; // 如果总token数大于等于1000,进行费用扣除 if (totalTokens >= 1000) { // 计算费用 int token1 = totalTokens / 1000; int token2 = totalTokens % 1000; if (token2 > 0) { // 保存剩余tokens chatToken.setModelName(modelName); chatToken.setUserId(getUserId()); chatToken.setToken(token2); chatTokenService.editToken(chatToken); } else { chatTokenService.resetToken(getUserId(), modelName); } ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName); double cost = chatModelVo.getModelPrice(); if (BillingType.TIMES.getCode().equals(chatModelVo.getModelType())) { // 按次数扣费 deductUserBalance(getUserId(), cost); chatMessageBo.setDeductCost(cost); }else { // 按token扣费 Double numberCost = token1 * cost; deductUserBalance(chatMessageBo.getUserId(), numberCost); chatMessageBo.setDeductCost(numberCost); } chatMessageBo.setContent(chatRequest.getPrompt()); } else { deductUserBalance(getUserId(), 0.0); chatMessageBo.setDeductCost(0d); chatMessageBo.setRemark("不满1kToken,计入下一次!"); chatToken.setToken(totalTokens); chatToken.setModelName(chatMessageBo.getModelName()); chatToken.setUserId(chatMessageBo.getUserId()); chatTokenService.editToken(chatToken); } // 保存消息记录 chatMessageService.insertByBo(chatMessageBo); } /** * 从用户余额中扣除费用 * * @param userId 用户ID * @param numberCost 要扣除的费用 */ @Override public void deductUserBalance(Long userId, Double numberCost) { SysUser sysUser = sysUserMapper.selectById(userId); if (sysUser == null) { return; } Double userBalance = sysUser.getUserBalance(); if (userBalance < numberCost || userBalance == 0) { throw new ServiceException("余额不足, 请充值"); } sysUserMapper.update(null, new LambdaUpdateWrapper() .set(SysUser::getUserBalance, Math.max(userBalance - numberCost, 0)) .eq(SysUser::getUserId, userId)); } /** * 扣除任务费用 */ @Override public void taskDeduct(String type,String prompt, double cost) { // 判断用户是否付费 checkUserGrade(); // 扣除费用 deductUserBalance(getUserId(), cost); // 保存消息记录 ChatMessageBo chatMessageBo = new ChatMessageBo(); chatMessageBo.setUserId(getUserId()); chatMessageBo.setModelName(type); chatMessageBo.setContent(prompt); chatMessageBo.setDeductCost(cost); chatMessageBo.setTotalTokens(0); chatMessageService.insertByBo(chatMessageBo); } /** * 判断用户是否付费 */ @Override public void checkUserGrade() { SysUser sysUser = sysUserMapper.selectById(getUserId()); if(UserGradeType.UNPAID.getCode().equals(sysUser.getUserGrade())){ throw new BaseException("该模型仅限付费用户使用。请升级套餐,开启高效体验之旅!"); } } /** * 获取用户Id */ public Long getUserId() { LoginUser loginUser = LoginHelper.getLoginUser(); if (loginUser == null) { throw new BaseException("用户未登录!"); } return loginUser.getUserId(); } }