办学质量监测教学评价系统
ageerle
2025-05-26 abcde9e36e8cb6e80bc091532d6e11789b9eb085
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
package org.ruoyi.chat.service.chat.impl;
 
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.chat.enums.BillingType;
import org.ruoyi.chat.enums.UserGradeType;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.common.chat.utils.TikTokensUtil;
import org.ruoyi.common.core.domain.model.LoginUser;
import org.ruoyi.common.core.exception.ServiceException;
import org.ruoyi.common.core.exception.base.BaseException;
import org.ruoyi.common.satoken.utils.LoginHelper;
import org.ruoyi.domain.ChatUsageToken;
import org.ruoyi.domain.bo.ChatMessageBo;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.IChatMessageService;
import org.ruoyi.service.IChatModelService;
import org.ruoyi.service.IChatTokenService;
import org.ruoyi.system.domain.SysUser;
import org.ruoyi.system.mapper.SysUserMapper;
import org.springframework.stereotype.Service;
 
 
/**
 * 计费管理Service实现
 *
 * @author ageerle
 * @date 2025-04-08
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatCostServiceImpl implements IChatCostService {
 
    private final SysUserMapper sysUserMapper;
 
    private final IChatMessageService chatMessageService;
 
    private final IChatTokenService chatTokenService;
 
    private final IChatModelService chatModelService;
 
    /**
     * 扣除用户余额
     */
    @Override
    public void deductToken(ChatRequest chatRequest) {
        if(chatRequest.getUserId()==null || chatRequest.getSessionId()==null){
            return;
        }
 
        int tokens = TikTokensUtil.tokens(chatRequest.getModel(), chatRequest.getPrompt());
 
        String modelName = chatRequest.getModel();
 
        ChatMessageBo chatMessageBo = new ChatMessageBo();
 
        // 设置用户id
        chatMessageBo.setUserId(chatRequest.getUserId());
        // 设置对话角色
        chatMessageBo.setRole(chatRequest.getRole());
        // 设置会话id
        chatMessageBo.setSessionId(chatRequest.getSessionId());
        // 设置对话内容
        chatMessageBo.setContent(chatRequest.getPrompt());
 
        // 计算总token数
        ChatUsageToken chatToken = chatTokenService.queryByUserId(chatMessageBo.getUserId(), modelName);
        if (chatToken == null) {
            chatToken = new ChatUsageToken();
            chatToken.setToken(0);
        }
        int totalTokens = chatToken.getToken() + tokens;
        // 如果总token数大于等于1000,进行费用扣除
        if (totalTokens >= 1000) {
            // 计算费用
            int token1 = totalTokens / 1000;
            int token2 = totalTokens % 1000;
            if (token2 > 0) {
                // 保存剩余tokens
                chatToken.setModelName(modelName);
                chatToken.setUserId(chatMessageBo.getUserId());
                chatToken.setToken(token2);
                chatTokenService.editToken(chatToken);
            } else {
                chatTokenService.resetToken(chatMessageBo.getUserId(), modelName);
            }
            ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName);
            double cost = chatModelVo.getModelPrice();
            if (BillingType.TIMES.getCode().equals(chatModelVo.getModelType())) {
                // 按次数扣费
                deductUserBalance(chatMessageBo.getUserId(), cost);
                chatMessageBo.setDeductCost(cost);
            }else {
                // 按token扣费
                Double numberCost = token1 * cost;
                deductUserBalance(chatMessageBo.getUserId(), numberCost);
                chatMessageBo.setDeductCost(numberCost);
            }
 
        } else {
            deductUserBalance(chatMessageBo.getUserId(), 0.0);
            chatMessageBo.setDeductCost(0d);
            chatMessageBo.setRemark("不满1kToken,计入下一次!");
            chatToken.setToken(totalTokens);
            chatToken.setModelName(chatMessageBo.getModelName());
            chatToken.setUserId(chatMessageBo.getUserId());
            chatTokenService.editToken(chatToken);
        }
        // 保存消息记录
        chatMessageService.insertByBo(chatMessageBo);
    }
 
    /**
     * 从用户余额中扣除费用
     *
     * @param userId     用户ID
     * @param numberCost 要扣除的费用
     */
    @Override
    public void deductUserBalance(Long userId, Double numberCost) {
        SysUser sysUser = sysUserMapper.selectById(userId);
        if (sysUser == null) {
            return;
        }
 
        Double userBalance = sysUser.getUserBalance();
        if (userBalance < numberCost || userBalance == 0) {
            throw new ServiceException("余额不足, 请充值");
        }
        sysUserMapper.update(null,
            new LambdaUpdateWrapper<SysUser>()
                .set(SysUser::getUserBalance, Math.max(userBalance - numberCost, 0))
                .eq(SysUser::getUserId, userId));
    }
 
    /**
     * 扣除任务费用
     */
    @Override
    public void taskDeduct(String type,String prompt, double cost) {
        // 判断用户是否付费
        checkUserGrade();
        // 扣除费用
        deductUserBalance(getUserId(), cost);
        // 保存消息记录
        ChatMessageBo chatMessageBo = new ChatMessageBo();
        chatMessageBo.setUserId(getUserId());
        chatMessageBo.setModelName(type);
        chatMessageBo.setContent(prompt);
        chatMessageBo.setDeductCost(cost);
        chatMessageBo.setTotalTokens(0);
        chatMessageService.insertByBo(chatMessageBo);
    }
 
    /**
     * 判断用户是否付费
     */
    @Override
    public void checkUserGrade() {
        SysUser sysUser = sysUserMapper.selectById(getUserId());
        if(UserGradeType.UNPAID.getCode().equals(sysUser.getUserGrade())){
            throw new BaseException("该模型仅限付费用户使用。请升级套餐,开启高效体验之旅!");
        }
    }
 
    /**
     * 获取用户Id
     */
    public Long getUserId() {
        LoginUser loginUser = LoginHelper.getLoginUser();
        if (loginUser == null) {
            throw new BaseException("用户未登录!");
        }
        return loginUser.getUserId();
    }
}