package org.ruoyi.common.chat.utils; import cn.hutool.core.util.StrUtil; import com.knuddels.jtokkit.Encodings; import com.knuddels.jtokkit.api.Encoding; import com.knuddels.jtokkit.api.EncodingRegistry; import com.knuddels.jtokkit.api.EncodingType; import com.knuddels.jtokkit.api.ModelType; import lombok.extern.slf4j.Slf4j; import org.jetbrains.annotations.NotNull; import org.ruoyi.common.chat.entity.chat.ChatCompletion; import org.ruoyi.common.chat.entity.chat.FunctionCall; import org.ruoyi.common.chat.entity.chat.Message; import java.util.*; /** * token计算工具类 * * @author https:www.unfbx.com * @since 2023-04-04 */ @Slf4j public class TikTokensUtil { /** * 模型名称对应Encoding */ private static final Map modelMap = new HashMap<>(); /** * registry实例 */ private static final EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); static { for (ModelType modelType : ModelType.values()) { modelMap.put(modelType.getName(), registry.getEncodingForModel(modelType)); } modelMap.put(ChatCompletion.Model.GPT_3_5_TURBO_0613.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO)); modelMap.put(ChatCompletion.Model.GPT_3_5_TURBO_16K.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO)); modelMap.put(ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO)); modelMap.put(ChatCompletion.Model.GPT_3_5_TURBO_0125.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO)); modelMap.put(ChatCompletion.Model.GPT_4_32K.getName(), registry.getEncodingForModel(ModelType.GPT_4)); modelMap.put(ChatCompletion.Model.GPT_4_0613.getName(), registry.getEncodingForModel(ModelType.GPT_4)); modelMap.put(ChatCompletion.Model.GPT_4_32K_0613.getName(), registry.getEncodingForModel(ModelType.GPT_4)); modelMap.put(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName(), registry.getEncodingForModel(ModelType.GPT_4)); modelMap.put(ChatCompletion.Model.GPT_4_VISION_PREVIEW.getName(), registry.getEncodingForModel(ModelType.GPT_4)); modelMap.put(ChatCompletion.Model.GPT_4_0125_PREVIEW.getName(), registry.getEncodingForModel(ModelType.GPT_4)); } /** * 通过Encoding和text获取编码数组 * * @param enc Encoding类型 * @param text 文本信息 * @return 编码数组 */ public static List encode(@NotNull Encoding enc, String text) { return StrUtil.isBlank(text) ? new ArrayList<>() : enc.encode(text); } /** * 通过Encoding计算text信息的tokens * * @param enc Encoding类型 * @param text 文本信息 * @return tokens数量 */ public static int tokens(@NotNull Encoding enc, String text) { return encode(enc, text).size(); } /** * 通过Encoding和encoded数组反推text信息 * * @param enc Encoding * @param encoded 编码数组 * @return 编码数组对应的文本信息 */ public static String decode(@NotNull Encoding enc, @NotNull List encoded) { return enc.decode(encoded); } /** * 获取一个Encoding对象,通过Encoding类型 * * @param encodingType encodingType * @return Encoding */ public static Encoding getEncoding(@NotNull EncodingType encodingType) { return registry.getEncoding(encodingType); } /** * 获取encode的编码数组 * * @param text 文本信息 * @return 编码数组 */ public static List encode(@NotNull EncodingType encodingType, String text) { if (StrUtil.isBlank(text)) { return new ArrayList<>(); } Encoding enc = getEncoding(encodingType); return enc.encode(text); } /** * 计算指定字符串的tokens,通过EncodingType * * @param encodingType encodingType * @param text 文本信息 * @return tokens数量 */ public static int tokens(@NotNull EncodingType encodingType, String text) { return encode(encodingType, text).size(); } /** * 通过EncodingType和encoded编码数组,反推字符串文本 * * @param encodingType encodingType * @param encoded 编码数组 * @return 编码数组对应的字符串 */ public static String decode(@NotNull EncodingType encodingType, @NotNull List encoded) { Encoding enc = getEncoding(encodingType); return enc.decode(encoded); } /** * 获取一个Encoding对象,通过模型名称 * * @param modelName 模型名称 * @return Encoding */ public static Encoding getEncoding(@NotNull String modelName) { return modelMap.getOrDefault(modelName, modelMap.get(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())); } /** * 获取encode的编码数组,通过模型名称 * * @param text 文本信息 * @return 编码数组 */ public static List encode(@NotNull String modelName, String text) { if (StrUtil.isBlank(text)) { return new ArrayList<>(); } Encoding enc = getEncoding(modelName); if (Objects.isNull(enc)) { log.warn("[{}]模型不存在或者暂不支持计算tokens,直接返回tokens==0",modelName); return new ArrayList<>(); } return enc.encode(text); } /** * 通过模型名称, 计算指定字符串的tokens * * @param modelName 模型名称 * @param text 文本信息 * @return tokens数量 */ public static int tokens(@NotNull String modelName, String text) { return encode(modelName, text).size(); } /** * 通过模型名称计算messages获取编码数组 * 参考官方的处理逻辑: * https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb * * @param modelName 模型名称 * @param messages 消息体 * @return tokens数量 */ public static int tokens(@NotNull String modelName, @NotNull List messages) { Encoding encoding = getEncoding(modelName); int tokensPerMessage = 0; int tokensPerName = 0; if (modelName.equals(ChatCompletion.Model.GPT_3_5_TURBO_0613.getName()) || modelName.equals(ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName()) || modelName.equals(ChatCompletion.Model.GPT_4_0613.getName()) || modelName.equals(ChatCompletion.Model.GPT_4_32K_0613.getName()) || modelName.equals(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) || modelName.equals(ChatCompletion.Model.GPT_4_VISION_PREVIEW.getName()) ) { tokensPerMessage = 3; tokensPerName = 1; }else if(modelName.contains(ChatCompletion.Model.GPT_3_5_TURBO.getName())){ //"gpt-3.5-turbo" in model: log.warn("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."); tokensPerMessage = 3; tokensPerName = 1; }else if(modelName.contains(ChatCompletion.Model.GPT_4.getName())){ log.warn("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."); tokensPerMessage = 3; tokensPerName = 1; }else { log.warn("不支持的model {} 按gpt4计算tokens",modelName); tokensPerMessage = 3; tokensPerName = 1; } int sum = 0; for (Message msg : messages) { sum += tokensPerMessage; sum += tokens(encoding, msg.getContent().toString()); sum += tokens(encoding, msg.getRole()); sum += tokens(encoding, msg.getName()); FunctionCall functionCall = msg.getFunctionCall(); sum += Objects.isNull(functionCall) ? 0 : tokens(encoding, functionCall.toString()); if (StrUtil.isNotBlank(msg.getName())) { sum += tokensPerName; } } sum += 3; return sum; } /** * 通过模型名称和encoded编码数组,反推字符串文本 * * @param modelName 模型名 * @param encoded 编码数组 * @return 返回源文本 */ public static String decode(@NotNull String modelName, @NotNull List encoded) { Encoding enc = getEncoding(modelName); return enc.decode(encoded); } }