package com.xmzs.common.chat.utils;
|
|
import cn.hutool.core.util.StrUtil;
|
import com.knuddels.jtokkit.Encodings;
|
import com.knuddels.jtokkit.api.Encoding;
|
import com.knuddels.jtokkit.api.EncodingRegistry;
|
import com.knuddels.jtokkit.api.EncodingType;
|
import com.knuddels.jtokkit.api.ModelType;
|
import com.xmzs.common.chat.entity.chat.BaseChatCompletion;
|
import lombok.extern.slf4j.Slf4j;
|
|
import com.xmzs.common.chat.entity.chat.ChatCompletion;
|
import com.xmzs.common.chat.entity.chat.FunctionCall;
|
import com.xmzs.common.chat.entity.chat.Message;
|
import org.jetbrains.annotations.NotNull;
|
|
import java.util.*;
|
|
/**
|
* 描述:token计算工具类
|
*
|
* @author https:www.unfbx.com
|
* @since 2023-04-04
|
*/
|
@Slf4j
|
public class TikTokensUtil {
|
/**
|
* 模型名称对应Encoding
|
*/
|
private static final Map<String, Encoding> modelMap = new HashMap<>();
|
/**
|
* registry实例
|
*/
|
private static final EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
|
|
static {
|
for (ModelType modelType : ModelType.values()) {
|
modelMap.put(modelType.getName(), registry.getEncodingForModel(modelType));
|
}
|
modelMap.put(ChatCompletion.Model.GPT_3_5_TURBO_0613.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO));
|
modelMap.put(ChatCompletion.Model.GPT_3_5_TURBO_16K.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO));
|
modelMap.put(ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO));
|
modelMap.put(ChatCompletion.Model.GPT_3_5_TURBO_0125.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO));
|
modelMap.put(ChatCompletion.Model.GPT_4_32K.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
modelMap.put(ChatCompletion.Model.GPT_4_0613.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
modelMap.put(ChatCompletion.Model.GPT_4_32K_0613.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
modelMap.put(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
modelMap.put(ChatCompletion.Model.GPT_4_VISION_PREVIEW.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
modelMap.put(ChatCompletion.Model.GPT_4_0125_PREVIEW.getName(), registry.getEncodingForModel(ModelType.GPT_4));
|
}
|
|
/**
|
* 通过Encoding和text获取编码数组
|
*
|
* @param enc Encoding类型
|
* @param text 文本信息
|
* @return 编码数组
|
*/
|
public static List<Integer> encode(@NotNull Encoding enc, String text) {
|
return StrUtil.isBlank(text) ? new ArrayList<>() : enc.encode(text);
|
}
|
|
/**
|
* 通过Encoding计算text信息的tokens
|
*
|
* @param enc Encoding类型
|
* @param text 文本信息
|
* @return tokens数量
|
*/
|
public static int tokens(@NotNull Encoding enc, String text) {
|
return encode(enc, text).size();
|
}
|
|
|
/**
|
* 通过Encoding和encoded数组反推text信息
|
*
|
* @param enc Encoding
|
* @param encoded 编码数组
|
* @return 编码数组对应的文本信息
|
*/
|
public static String decode(@NotNull Encoding enc, @NotNull List<Integer> encoded) {
|
return enc.decode(encoded);
|
}
|
|
/**
|
* 获取一个Encoding对象,通过Encoding类型
|
*
|
* @param encodingType encodingType
|
* @return Encoding
|
*/
|
public static Encoding getEncoding(@NotNull EncodingType encodingType) {
|
return registry.getEncoding(encodingType);
|
}
|
|
/**
|
* 获取encode的编码数组
|
*
|
* @param text 文本信息
|
* @return 编码数组
|
*/
|
public static List<Integer> encode(@NotNull EncodingType encodingType, String text) {
|
if (StrUtil.isBlank(text)) {
|
return new ArrayList<>();
|
}
|
Encoding enc = getEncoding(encodingType);
|
return enc.encode(text);
|
}
|
|
/**
|
* 计算指定字符串的tokens,通过EncodingType
|
*
|
* @param encodingType encodingType
|
* @param text 文本信息
|
* @return tokens数量
|
*/
|
public static int tokens(@NotNull EncodingType encodingType, String text) {
|
return encode(encodingType, text).size();
|
}
|
|
|
/**
|
* 通过EncodingType和encoded编码数组,反推字符串文本
|
*
|
* @param encodingType encodingType
|
* @param encoded 编码数组
|
* @return 编码数组对应的字符串
|
*/
|
public static String decode(@NotNull EncodingType encodingType, @NotNull List<Integer> encoded) {
|
Encoding enc = getEncoding(encodingType);
|
return enc.decode(encoded);
|
}
|
|
|
/**
|
* 获取一个Encoding对象,通过模型名称
|
*
|
* @param modelName 模型名称
|
* @return Encoding
|
*/
|
public static Encoding getEncoding(@NotNull String modelName) {
|
return modelMap.get(modelName);
|
}
|
|
/**
|
* 获取encode的编码数组,通过模型名称
|
*
|
* @param text 文本信息
|
* @return 编码数组
|
*/
|
public static List<Integer> encode(@NotNull String modelName, String text) {
|
if (StrUtil.isBlank(text)) {
|
return new ArrayList<>();
|
}
|
Encoding enc = getEncoding(modelName);
|
if (Objects.isNull(enc)) {
|
log.warn("[{}]模型不存在或者暂不支持计算tokens,直接返回tokens==0",modelName);
|
return new ArrayList<>();
|
}
|
return enc.encode(text);
|
}
|
|
/**
|
* 通过模型名称, 计算指定字符串的tokens
|
*
|
* @param modelName 模型名称
|
* @param text 文本信息
|
* @return tokens数量
|
*/
|
public static int tokens(@NotNull String modelName, String text) {
|
return encode(modelName, text).size();
|
}
|
|
|
/**
|
* 通过模型名称计算messages获取编码数组
|
* 参考官方的处理逻辑:
|
* <a href=https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb>https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb</a>
|
*
|
* @param modelName 模型名称
|
* @param messages 消息体
|
* @return tokens数量
|
*/
|
public static int tokens(@NotNull String modelName, @NotNull List<Message> messages) {
|
Encoding encoding = getEncoding(modelName);
|
int tokensPerMessage = 0;
|
int tokensPerName = 0;
|
if (modelName.equals(ChatCompletion.Model.GPT_3_5_TURBO_0613.getName())
|
|| modelName.equals(ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName())
|
|| modelName.equals(ChatCompletion.Model.GPT_4_0613.getName())
|
|| modelName.equals(ChatCompletion.Model.GPT_4_32K_0613.getName())
|
|| modelName.equals(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
|| modelName.equals(ChatCompletion.Model.GPT_4_VISION_PREVIEW.getName())
|
) {
|
tokensPerMessage = 3;
|
tokensPerName = 1;
|
}else if(modelName.contains(ChatCompletion.Model.GPT_3_5_TURBO.getName())){
|
//"gpt-3.5-turbo" in model:
|
log.warn("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.");
|
tokensPerMessage = 3;
|
tokensPerName = 1;
|
}else if(modelName.contains(ChatCompletion.Model.GPT_4.getName())){
|
log.warn("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.");
|
tokensPerMessage = 3;
|
tokensPerName = 1;
|
}else {
|
log.warn("不支持的model {}. See https://github.com/openai/openai-python/blob/main/chatml.md 更多信息.",modelName);
|
}
|
int sum = 0;
|
for (Message msg : messages) {
|
sum += tokensPerMessage;
|
sum += tokens(encoding, msg.getContent());
|
sum += tokens(encoding, msg.getRole());
|
sum += tokens(encoding, msg.getName());
|
FunctionCall functionCall = msg.getFunctionCall();
|
sum += Objects.isNull(functionCall) ? 0 : tokens(encoding, functionCall.toString());
|
if (StrUtil.isNotBlank(msg.getName())) {
|
sum += tokensPerName;
|
}
|
}
|
sum += 3;
|
return sum;
|
}
|
|
/**
|
* 通过模型名称和encoded编码数组,反推字符串文本
|
*
|
* @param modelName 模型名
|
* @param encoded 编码数组
|
* @return 返回源文本
|
*/
|
public static String decode(@NotNull String modelName, @NotNull List<Integer> encoded) {
|
Encoding enc = getEncoding(modelName);
|
return enc.decode(encoded);
|
}
|
|
}
|