/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.chat;
import cn.hutool.core.collection.ListUtil;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
import com.alibaba.dashscope.aigc.conversation.ConversationParam;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationOutput;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.common.MessageManager;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.tools.FunctionDefinition;
import com.alibaba.dashscope.tools.ToolCallBase;
import com.alibaba.dashscope.tools.ToolCallFunction;
import com.alibaba.dashscope.utils.ApiKeywords;
import com.alibaba.dashscope.utils.JsonUtils;
import io.reactivex.Flowable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
/**
 * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal Alibaba DashScope}
 * backed by {@link Generation}.
 *
 * @author yuluo
 * @since 2023.0.1.0
 * @see ChatModel
 * @see com.alibaba.dashscope.aigc.generation
 */
public class TongYiChatModel extends
		AbstractFunctionCallSupport<
				com.alibaba.dashscope.common.Message,
				ConversationParam,
				GenerationResult>
		implements ChatModel, StreamingChatModel {

	private static final Logger logger = LoggerFactory.getLogger(TongYiChatModel.class);

	/**
	 * DashScope generation client.
	 */
	private final Generation generation;

	/**
	 * Default TongYi chat options merged into every request; may be {@code null}.
	 */
	private final TongYiChatOptions defaultOptions;

	/**
	 * Conversation history used by {@link #call(Prompt)}. Populated by Spring field
	 * injection; it stays {@code null} when this model is constructed manually, in which
	 * case {@link #call(Prompt)} sends only the prompt's own messages.
	 */
	@Autowired
	private MessageManager msgManager;

	/**
	 * Creates a TongYi chat model with built-in default options
	 * (top-p 0.8, search enabled, message result format) and no function callbacks.
	 * @param generation DashScope generation client.
	 */
	public TongYiChatModel(Generation generation) {
		this(generation,
				TongYiChatOptions.builder()
						.withTopP(0.8)
						.withEnableSearch(true)
						.withResultFormat(ConversationParam.ResultFormat.MESSAGE)
						.build(),
				null
		);
	}

	/**
	 * Creates a TongYi chat model with the given default options and no function callbacks.
	 * @param generation DashScope generation client.
	 * @param options default TongYi model params; may be {@code null}.
	 */
	public TongYiChatModel(Generation generation, TongYiChatOptions options) {
		this(generation, options, null);
	}

	/**
	 * Creates a TongYi chat model.
	 * @param generation DashScope model generation client.
	 * @param options default TongYi chat options; may be {@code null}.
	 * @param functionCallbackContext context used to resolve function callbacks; may be {@code null}.
	 */
	public TongYiChatModel(Generation generation, TongYiChatOptions options,
			FunctionCallbackContext functionCallbackContext) {
		super(functionCallbackContext);
		this.generation = generation;
		this.defaultOptions = options;
	}

	/**
	 * Returns the default chat options this model was created with.
	 *
	 * @return the default {@link TongYiChatOptions}; may be {@code null}.
	 */
	public TongYiChatOptions getDefaultOptions() {
		return this.defaultOptions;
	}

	/**
	 * Performs a blocking chat call, including any function-call round trips.
	 * <p>
	 * When a {@link MessageManager} is available, the prompt text is appended to the
	 * shared history. The whole history is sent instead of the prompt's own messages,
	 * and the model's reply is recorded back into the history.
	 * @param prompt the prompt to send.
	 * @return one generation per returned choice; empty when the response has no choices.
	 */
	@Override
	public ChatResponse call(Prompt prompt) {
		ConversationParam params = toTongYiChatParams(prompt);

		if (this.msgManager != null) {
			// TongYi models context loader.
			com.alibaba.dashscope.common.Message message = new com.alibaba.dashscope.common.Message();
			message.setRole(Role.USER.getValue());
			message.setContent(prompt.getContents());
			this.msgManager.add(message);
			params.setMessages(this.msgManager.get());
		}

		logger.trace("TongYi ConversationOptions: {}", params);
		GenerationResult chatCompletions = this.callWithFunctionSupport(params);
		logger.trace("TongYi GenerationResult: {}", chatCompletions);

		if (this.msgManager != null) {
			this.msgManager.add(chatCompletions);
		}

		if (!hasChoices(chatCompletions)) {
			return new ChatResponse(List.of());
		}

		List<org.springframework.ai.chat.model.Generation> generations = chatCompletions
				.getOutput()
				.getChoices()
				.stream()
				.map(choice -> new org.springframework.ai.chat.model.Generation(choice.getMessage().getContent())
						.withGenerationMetadata(generateChoiceMetadata(choice)))
				.toList();

		return new ChatResponse(generations);
	}

	/**
	 * Performs a streaming chat call. Each streamed chunk becomes one {@link ChatResponse}
	 * built from its first choice; chunks without choices are skipped.
	 * @param prompt the prompt to send.
	 * @return a flux of chat responses, published on the parallel scheduler.
	 */
	@Override
	public Flux<ChatResponse> stream(Prompt prompt) {
		// Incremental output is enabled by toTongYiChatParams. See
		// https://help.aliyun.com/zh/dashscope/developer-reference/api-details?spm=a2c4g.11186623.0.0.655fc11aRR0jj7#b9ad0a10cfhpe
		ConversationParam tongYiChatParams = toTongYiChatParams(prompt);
		Flowable<GenerationResult> genRes = openStream(tongYiChatParams);

		return Flux.from(genRes)
				.filter(TongYiChatModel::hasChoices)
				.map(result -> {
					GenerationOutput.Choice choice = result.getOutput().getChoices().get(0);
					var gen = new org.springframework.ai.chat.model.Generation(choice.getMessage().getContent())
							.withGenerationMetadata(generateChoiceMetadata(choice));
					return new ChatResponse(ListUtil.of(gen));
				})
				.publishOn(Schedulers.parallel());
	}

	/**
	 * Converts a Spring AI {@link Prompt} into DashScope {@link ConversationParam}s.
	 * <p>
	 * Precedence:
	 * <ul>
	 * <li>Built-in values are the base: model {@code QWEN_TURBO}, message result format,
	 * incremental output.</li>
	 * <li>Default options fill the fields that are still unset.</li>
	 * <li>Runtime prompt options override both.</li>
	 * </ul>
	 *
	 * @param prompt {@link Prompt}
	 * @return Qwen models params {@link ConversationParam}
	 * @throws IllegalArgumentException if the prompt options are not {@link ChatOptions}.
	 */
	public ConversationParam toTongYiChatParams(Prompt prompt) {
		Set<String> functionsForThisRequest = new HashSet<>();

		List<com.alibaba.dashscope.common.Message> tongYiMessage = prompt.getInstructions().stream()
				.map(this::fromSpringAIMessage)
				.toList();

		ConversationParam chatParams = ConversationParam.builder()
				.messages(tongYiMessage)
				// models setting
				// {@link HalfDuplexServiceParam#models}
				.model(Generation.Models.QWEN_TURBO)
				// {@link GenerationOutput}
				.resultFormat(ConversationParam.ResultFormat.MESSAGE)
				.incrementalOutput(true)
				.build();

		if (this.defaultOptions != null) {
			chatParams = merge(chatParams, this.defaultOptions);
			Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, !IS_RUNTIME_CALL);
			functionsForThisRequest.addAll(defaultEnabledFunctions);
		}

		if (prompt.getOptions() != null) {
			if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
				TongYiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
						ChatOptions.class, TongYiChatOptions.class);
				chatParams = merge(updatedRuntimeOptions, chatParams);
				Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
						IS_RUNTIME_CALL);
				functionsForThisRequest.addAll(promptEnabledFunctions);
			}
			else {
				throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
						+ prompt.getOptions().getClass().getSimpleName());
			}
		}

		// Add the enabled functions definitions to the request's tools parameter.
		if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
			List<FunctionDefinition> tools = this.getFunctionTools(functionsForThisRequest);
			// todo: the resolved tools are not yet attached to the request (chatParams.setTools(tools)).
		}

		return chatParams;
	}

	/**
	 * Builds Spring AI generation metadata from a DashScope choice.
	 * <p>
	 * The finish reason is stringified, so {@code null} becomes {@code "null"}. The
	 * choice's message content is passed as the second (content-filter) argument.
	 */
	private ChatGenerationMetadata generateChoiceMetadata(GenerationOutput.Choice choice) {
		return ChatGenerationMetadata.from(
				String.valueOf(choice.getFinishReason()),
				choice.getMessage().getContent()
		);
	}

	/**
	 * Resolves the named function callbacks into DashScope {@link FunctionDefinition}s.
	 * Each definition's parameters are built from the callback's JSON input schema.
	 */
	private List<FunctionDefinition> getFunctionTools(Set<String> functionNames) {
		return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
			FunctionDefinition functionDefinition = FunctionDefinition.builder()
					.name(functionCallback.getName())
					.description(functionCallback.getDescription())
					.parameters(JsonUtils.parametersToJsonObject(
							ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema())
					))
					.build();
			return functionDefinition;
		}).toList();
	}

	/**
	 * Merges options into request params.
	 * <p>
	 * For most fields the request's existing value wins and the option fills in when the
	 * request's value is {@code null}. The model is the exception: the options' model
	 * wins, and the request's model is kept when the options specify none, so the
	 * mandatory model is never nulled out.
	 * @param tongYiParams request params providing messages and primary field values.
	 * @param scaChatParams options supplying fallback values; may be {@code null}.
	 * @return a new merged param, or {@code tongYiParams} if {@code scaChatParams} is {@code null}.
	 */
	private ConversationParam merge(ConversationParam tongYiParams, TongYiChatOptions scaChatParams) {
		if (scaChatParams == null) {
			return tongYiParams;
		}

		return ConversationParam.builder()
				.messages(tongYiParams.getMessages())
				.maxTokens((tongYiParams.getMaxTokens() != null) ? tongYiParams.getMaxTokens() : scaChatParams.getMaxTokens())
				.model((scaChatParams.getModel() != null) ? scaChatParams.getModel() : tongYiParams.getModel())
				.resultFormat((tongYiParams.getResultFormat() != null) ? tongYiParams.getResultFormat() : scaChatParams.getResultFormat())
				.enableSearch((tongYiParams.getEnableSearch() != null) ? tongYiParams.getEnableSearch() : scaChatParams.getEnableSearch())
				.topK((tongYiParams.getTopK() != null) ? tongYiParams.getTopK() : scaChatParams.getTopK())
				.topP((tongYiParams.getTopP() != null) ? tongYiParams.getTopP() : scaChatParams.getTopP())
				.incrementalOutput((tongYiParams.getIncrementalOutput() != null) ? tongYiParams.getIncrementalOutput() : scaChatParams.getIncrementalOutput())
				.temperature((tongYiParams.getTemperature() != null) ? tongYiParams.getTemperature() : scaChatParams.getTemperature())
				.repetitionPenalty((tongYiParams.getRepetitionPenalty() != null) ? tongYiParams.getRepetitionPenalty() : scaChatParams.getRepetitionPenalty())
				.seed((tongYiParams.getSeed() != null) ? tongYiParams.getSeed() : scaChatParams.getSeed())
				.build();
	}

	/**
	 * Applies runtime options on top of existing request params.
	 * <p>
	 * The options' non-null max tokens, stop strings, temperature, top-k and top-p
	 * override the request's values. All other fields are merged as in
	 * {@link #merge(ConversationParam, TongYiChatOptions)}.
	 * @param scaChatParams runtime options; may be {@code null}.
	 * @param tongYiParams the current request params.
	 * @return the merged params, or {@code tongYiParams} if {@code scaChatParams} is {@code null}.
	 */
	private ConversationParam merge(TongYiChatOptions scaChatParams, ConversationParam tongYiParams) {
		if (scaChatParams == null) {
			return tongYiParams;
		}

		ConversationParam mergedTongYiParams = merge(tongYiParams, scaChatParams);

		if (scaChatParams.getMaxTokens() != null) {
			mergedTongYiParams.setMaxTokens(scaChatParams.getMaxTokens());
		}
		if (scaChatParams.getStop() != null) {
			mergedTongYiParams.setStopStrings(scaChatParams.getStop());
		}
		if (scaChatParams.getTemperature() != null) {
			mergedTongYiParams.setTemperature(scaChatParams.getTemperature());
		}
		if (scaChatParams.getTopK() != null) {
			mergedTongYiParams.setTopK(scaChatParams.getTopK());
		}
		if (scaChatParams.getTopP() != null) {
			mergedTongYiParams.setTopP(scaChatParams.getTopP());
		}

		return mergedTongYiParams;
	}

	/**
	 * Converts a Spring AI message into a DashScope message with the matching role.
	 * @throws IllegalArgumentException for message types other than user, system or assistant.
	 */
	private com.alibaba.dashscope.common.Message fromSpringAIMessage(Message message) {
		String role = switch (message.getMessageType()) {
			case USER -> Role.USER.getValue();
			case SYSTEM -> Role.SYSTEM.getValue();
			case ASSISTANT -> Role.ASSISTANT.getValue();
			default -> throw new IllegalArgumentException("Unknown message type " + message.getMessageType());
		};
		return com.alibaba.dashscope.common.Message.builder()
				.role(role)
				.content(message.getContent())
				.build();
	}

	/**
	 * Executes each requested function tool call.
	 * <p>
	 * Each result is appended to the conversation history as a {@code tool} message with
	 * the originating tool-call id. The follow-up request is then produced by reusing the
	 * previous request's settings with the updated history.
	 * @throws IllegalStateException if a requested function has no registered callback.
	 */
	@Override
	protected ConversationParam doCreateToolResponseRequest(
			ConversationParam previousRequest,
			com.alibaba.dashscope.common.Message responseMessage,
			List<com.alibaba.dashscope.common.Message> conversationHistory
	) {
		var toolCalls = responseMessage.getToolCalls();
		if (toolCalls != null) {
			for (ToolCallBase toolCall : toolCalls) {
				if (toolCall instanceof ToolCallFunction toolCallFunction && toolCallFunction.getFunction() != null) {
					var functionName = toolCallFunction.getFunction().getName();
					var functionArguments = toolCallFunction.getFunction().getArguments();

					if (!this.functionCallbackRegister.containsKey(functionName)) {
						throw new IllegalStateException("No function callback found for function name: " + functionName);
					}

					String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);

					// Add the function response to the conversation as a tool result.
					conversationHistory.add(com.alibaba.dashscope.common.Message.builder()
							.content(functionResponse)
							.role(Role.TOOL.getValue())
							.toolCallId(toolCall.getId())
							.build()
					);
				}
			}
		}

		// ModelOptionsUtils.merge cannot be used here because ConversationParam has no
		// @JsonProperty fields; keep every setting of the previous request and only swap
		// in the extended conversation history.
		previousRequest.setMessages(conversationHistory);
		return previousRequest;
	}

	/**
	 * Returns the messages carried by the given request.
	 */
	@Override
	protected List<com.alibaba.dashscope.common.Message> doGetUserMessages(ConversationParam request) {
		return request.getMessages();
	}

	/**
	 * Builds an assistant message with empty content carrying the tool calls of the
	 * response's first choice.
	 */
	@Override
	protected com.alibaba.dashscope.common.Message doGetToolResponseMessage(GenerationResult response) {
		var message = response.getOutput().getChoices().get(0).getMessage();
		var assistantMessage = com.alibaba.dashscope.common.Message.builder()
				.role(Role.ASSISTANT.getValue())
				.content("")
				.build();
		assistantMessage.setToolCalls(message.getToolCalls());
		return assistantMessage;
	}

	/**
	 * Performs a single blocking DashScope call.
	 * @throws RuntimeException wrapping {@link NoApiKeyException} or {@link InputRequiredException}.
	 */
	@Override
	protected GenerationResult doChatCompletion(ConversationParam request) {
		try {
			return this.generation.call(request);
		}
		catch (NoApiKeyException | InputRequiredException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Performs a single streaming DashScope call.
	 * @throws TongYiException if the API key or required input is missing.
	 */
	@Override
	protected Flux<GenerationResult> doChatCompletionStream(ConversationParam request) {
		return Flux.from(openStream(request));
	}

	/**
	 * Reports whether the first choice of the response finished because of tool calls.
	 */
	@Override
	protected boolean isToolFunctionCall(GenerationResult response) {
		if (!hasChoices(response)) {
			return false;
		}

		var choice = response.getOutput().getChoices().get(0);
		if (choice == null || choice.getFinishReason() == null) {
			return false;
		}

		return Objects.equals(choice.getFinishReason(), ApiKeywords.TOOL_CALLS);
	}

	/**
	 * Starts a DashScope streaming call.
	 * <p>
	 * Checked SDK exceptions are translated into {@link TongYiException}. The original
	 * exception is logged with its stack trace.
	 */
	private Flowable<GenerationResult> openStream(ConversationParam request) {
		try {
			return this.generation.streamCall(request);
		}
		catch (NoApiKeyException | InputRequiredException e) {
			logger.warn("TongYi chat client: {}", e.getMessage(), e);
			throw new TongYiException(e.getMessage());
		}
	}

	/**
	 * Returns {@code true} when the result has an output with at least one choice.
	 */
	private static boolean hasChoices(GenerationResult result) {
		return result != null
				&& result.getOutput() != null
				&& !CollectionUtils.isEmpty(result.getOutput().getChoices());
	}

}