/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.chat;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.util.Assert;
import java.util.*;
/**
 * Mutable option holder for TongYi chat requests.
 *
 * <p>Each field carries one generation parameter. Most options may be {@code null};
 * all getters and setters in this class pass {@code null} through unchanged.
 * Instances can be created directly or through {@link #builder()}.
 *
 * <p>This class performs no synchronization.
 *
 * @author yuluo
 * @since 2023.0.1.0
 */
public class TongYiChatOptions implements FunctionCallingOptions, ChatOptions {

    /**
     * TongYi model name, see {@link Generation.Models}.
     */
    private String model = Generation.Models.QWEN_TURBO;

    /**
     * Random number seed used during generation; it lets the caller control the randomness
     * of the generated content. The upstream documentation states that seed supports
     * unsigned 64-bit integers with a default value of 1234. When a seed is used the model
     * tries to produce the same or similar results, but identical output is not guaranteed.
     *
     * <p>NOTE(review): this field is an {@code Integer}, so only the 32-bit signed range
     * can be represented here.
     */
    private Integer seed = 1234;

    /**
     * Maximum number of tokens the model may generate. It is an upper limit only and does
     * not guarantee that this many tokens are produced.
     * For qwen-turbo the maximum and default values are 1500 tokens.
     * For qwen-max, qwen-max-1201, qwen-max-longcontext and qwen-plus the maximum and
     * default values are 2000 tokens.
     */
    private Integer maxTokens = 1500;

    /**
     * Probability threshold of nucleus (top-p) sampling. For example, with a value of 0.8
     * only the smallest set of most probable tokens whose probabilities add up to at least
     * 0.8 is kept as the candidate set.
     * The range of values is (0, 1.0). The larger the value, the higher the randomness of
     * generation; the lower the value, the higher the certainty.
     */
    private Double topP = 0.8;

    /**
     * Size of the sampling candidate set. For example, with a value of 50 only the 50
     * highest scoring tokens of a single generation step form the randomly sampled
     * candidate set. The larger the value, the higher the randomness; the smaller the
     * value, the higher the certainty.
     * This parameter is not passed by default. A missing value, or a value greater than
     * 100, indicates that the top_k policy is not enabled, in which case only the top_p
     * policy is in effect.
     */
    private Integer topK;

    /**
     * Controls the repetitiveness of the generated text. Increasing repetition_penalty
     * reduces repetition; 1.0 means no penalty.
     */
    private Double repetitionPenalty = 1.1;

    /**
     * Controls the degree of randomness and diversity. The temperature determines how much
     * the probability distribution of each candidate word is smoothed when generating text.
     * Higher values flatten the distribution, allowing more low-probability words to be
     * selected and producing more diverse results; lower values sharpen it, making
     * high-probability words more likely and producing more deterministic results.
     * Range: [0, 2); 0 is not recommended because it is meaningless.
     * java version &gt;= 2.5.1 (presumably the DashScope Java SDK version - confirm).
     */
    private Double temperature = 0.85;

    /**
     * Stop sequences used to control the generation process precisely: generation stops
     * automatically when the output is about to contain one of the specified strings or
     * token_ids, and the output does not contain the specified content.
     * For example, if stop is "Hello", generation stops when "Hello" is about to be
     * generated; if stop is [37763, 367], generation stops when "Observation" is about to
     * be generated. A list of strings or of token_id arrays may be passed to use multiple
     * stops. Do not mix strings and token_ids in one list; all elements must share one type.
     *
     * <p>NOTE(review): declared as a raw {@code List}; the element type is not enforced.
     */
    private List stop;

    /**
     * Whether to use streaming output. In stream mode the interface returns the result as
     * a generator that has to be iterated. By default each output is the whole sequence
     * generated so far and the last output is the final result; this can be changed with
     * the incremental_output parameter.
     */
    private Boolean stream = false;

    /**
     * Controls whether the model's built-in Internet search service is consulted when
     * generating text:
     * true - search is enabled and the results are used as reference information, although
     * the model decides by itself whether to use them;
     * false (default) - Internet search is disabled.
     */
    private Boolean enableSearch = false;

    /**
     * Result format, one of [text|message]. The upstream documentation names text as the
     * service default and recommends the message format; this class initializes the field
     * to {@code GenerationParam.ResultFormat.MESSAGE}.
     */
    private String resultFormat = GenerationParam.ResultFormat.MESSAGE;

    /**
     * Controls the streaming output mode. When false, every output contains the content
     * that has already been output. When true, incremental output is enabled: each output
     * contains only new content and the caller has to splice the pieces together.
     */
    private Boolean incrementalOutput = false;

    /**
     * Tools the model can optionally call. Currently only functions are supported, and
     * even if multiple functions are supplied the model selects only one of them.
     *
     * <p>NOTE(review): declared as a raw {@code List}; the element type is not enforced.
     */
    private List tools;

    /**
     * Function callbacks registered with these options.
     *
     * <p>NOTE(review): raw {@code List}; presumably holds {@code FunctionCallback}
     * instances as required by {@code FunctionCallingOptions} - confirm.
     */
    private List functionCallbacks = new ArrayList<>();

    /**
     * Names of the functions enabled for these options (strings are added by
     * {@link Builder#withFunction(String)}).
     */
    private Set functions = new HashSet<>();

    /**
     * Narrows a nullable {@code Double} to a {@code Float}.
     *
     * @param value the value to convert, may be {@code null}
     * @return the narrowed value, or {@code null} when {@code value} is {@code null}
     */
    private static Float toFloat(Double value) {
        return value == null ? null : Float.valueOf(value.floatValue());
    }

    /**
     * Widens a nullable {@code Float} to a {@code Double}. The widening keeps the float's
     * binary value, so e.g. {@code 0.85f} becomes {@code 0.8500000238418579}.
     *
     * @param value the value to convert, may be {@code null}
     * @return the widened value, or {@code null} when {@code value} is {@code null}
     */
    private static Double toDouble(Float value) {
        return value == null ? null : Double.valueOf(value.doubleValue());
    }

    /**
     * @return the temperature narrowed to {@code Float}, or {@code null} if it is unset
     */
    @Override
    public Float getTemperature() {
        // Null-safe: the builder can legitimately store null here.
        return toFloat(this.temperature);
    }

    /**
     * @param temperature the new temperature; {@code null} clears the option
     */
    public void setTemperature(Float temperature) {
        this.temperature = toDouble(temperature);
    }

    /**
     * @return the top-p threshold narrowed to {@code Float}, or {@code null} if it is unset
     */
    @Override
    public Float getTopP() {
        return toFloat(this.topP);
    }

    /**
     * @param topP the new top-p threshold; {@code null} clears the option
     */
    public void setTopP(Float topP) {
        this.topP = toDouble(topP);
    }

    /**
     * @return the top-k candidate set size, or {@code null} if it is unset
     */
    @Override
    public Integer getTopK() {
        return this.topK;
    }

    /**
     * @param topK the new top-k candidate set size; {@code null} clears the option
     */
    public void setTopK(Integer topK) {
        this.topK = topK;
    }

    /**
     * @return the model name
     */
    public String getModel() {
        return this.model;
    }

    /**
     * @param model the model name to use
     */
    public void setModel(String model) {
        this.model = model;
    }

    /**
     * @return the random seed, or {@code null} if it is unset
     */
    public Integer getSeed() {
        return this.seed;
    }

    /**
     * @param seed the random seed to use
     */
    public void setSeed(Integer seed) {
        this.seed = seed;
    }

    /**
     * @return the result format
     */
    public String getResultFormat() {
        return this.resultFormat;
    }

    /**
     * @param resultFormat the result format to use
     */
    public void setResultFormat(String resultFormat) {
        this.resultFormat = resultFormat;
    }

    /**
     * @return the maximum number of tokens to generate, or {@code null} if it is unset
     */
    public Integer getMaxTokens() {
        return this.maxTokens;
    }

    /**
     * @param maxTokens the maximum number of tokens to generate
     */
    public void setMaxTokens(Integer maxTokens) {
        this.maxTokens = maxTokens;
    }

    /**
     * @return the repetition penalty narrowed to {@code Float}, or {@code null} if unset
     */
    public Float getRepetitionPenalty() {
        return toFloat(this.repetitionPenalty);
    }

    /**
     * @param repetitionPenalty the new repetition penalty; {@code null} clears the option
     */
    public void setRepetitionPenalty(Float repetitionPenalty) {
        this.repetitionPenalty = toDouble(repetitionPenalty);
    }

    /**
     * @return the stop list as stored (no copy is made), or {@code null} if it is unset
     */
    public List getStop() {
        return this.stop;
    }

    /**
     * @param stop the stop list; the reference is stored as-is
     */
    public void setStop(List stop) {
        this.stop = stop;
    }

    /**
     * @return the streaming flag
     */
    public Boolean getStream() {
        return this.stream;
    }

    /**
     * @param stream the streaming flag
     */
    public void setStream(Boolean stream) {
        this.stream = stream;
    }

    /**
     * @return the Internet search flag
     */
    public Boolean getEnableSearch() {
        return this.enableSearch;
    }

    /**
     * @param enableSearch the Internet search flag
     */
    public void setEnableSearch(Boolean enableSearch) {
        this.enableSearch = enableSearch;
    }

    /**
     * @return the incremental output flag
     */
    public Boolean getIncrementalOutput() {
        return this.incrementalOutput;
    }

    /**
     * @param incrementalOutput the incremental output flag
     */
    public void setIncrementalOutput(Boolean incrementalOutput) {
        this.incrementalOutput = incrementalOutput;
    }

    /**
     * @return the tool list as stored (no copy is made), or {@code null} if it is unset
     */
    public List getTools() {
        return this.tools;
    }

    /**
     * @param tools the tool list; the reference is stored as-is
     */
    public void setTools(List tools) {
        this.tools = tools;
    }

    /**
     * @return the function callback list as stored (no copy is made)
     */
    @Override
    public List getFunctionCallbacks() {
        return this.functionCallbacks;
    }

    /**
     * @param functionCallbacks the function callback list; the reference is stored as-is
     */
    @Override
    public void setFunctionCallbacks(List functionCallbacks) {
        this.functionCallbacks = functionCallbacks;
    }

    /**
     * @return the function name set as stored (no copy is made)
     */
    @Override
    public Set getFunctions() {
        return this.functions;
    }

    /**
     * @param functions the function name set; the reference is stored as-is
     */
    @Override
    public void setFunctions(Set functions) {
        this.functions = functions;
    }

    /**
     * Two instances are equal when they are of the same class and every option field is
     * equal according to {@link Objects#equals(Object, Object)}.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        TongYiChatOptions that = (TongYiChatOptions) o;
        return Objects.equals(model, that.model)
                && Objects.equals(seed, that.seed)
                && Objects.equals(maxTokens, that.maxTokens)
                && Objects.equals(topP, that.topP)
                && Objects.equals(topK, that.topK)
                && Objects.equals(repetitionPenalty, that.repetitionPenalty)
                && Objects.equals(temperature, that.temperature)
                && Objects.equals(stop, that.stop)
                && Objects.equals(stream, that.stream)
                && Objects.equals(enableSearch, that.enableSearch)
                && Objects.equals(resultFormat, that.resultFormat)
                && Objects.equals(incrementalOutput, that.incrementalOutput)
                && Objects.equals(tools, that.tools)
                && Objects.equals(functionCallbacks, that.functionCallbacks)
                && Objects.equals(functions, that.functions);
    }

    /**
     * Hash over the same fields that {@link #equals(Object)} compares.
     */
    @Override
    public int hashCode() {
        return Objects.hash(
                model,
                seed,
                maxTokens,
                topP,
                topK,
                repetitionPenalty,
                temperature,
                stop,
                stream,
                enableSearch,
                resultFormat,
                incrementalOutput,
                tools,
                functionCallbacks,
                functions
        );
    }

    /**
     * @return a string listing every option field and its current value
     */
    @Override
    public String toString() {
        // The first entry must not be preceded by a separator.
        final StringBuilder sb = new StringBuilder("TongYiChatOptions{");
        sb.append("model='").append(model).append('\'');
        sb.append(", seed=").append(seed);
        sb.append(", maxTokens=").append(maxTokens);
        sb.append(", topP=").append(topP);
        sb.append(", topK=").append(topK);
        sb.append(", repetitionPenalty=").append(repetitionPenalty);
        sb.append(", temperature=").append(temperature);
        sb.append(", stop=").append(stop);
        sb.append(", stream=").append(stream);
        sb.append(", enableSearch=").append(enableSearch);
        sb.append(", resultFormat='").append(resultFormat).append('\'');
        sb.append(", incrementalOutput=").append(incrementalOutput);
        sb.append(", tools=").append(tools);
        sb.append(", functionCallbacks=").append(functionCallbacks);
        sb.append(", functions=").append(functions);
        sb.append('}');
        return sb.toString();
    }

    /**
     * @return a new builder backed by a fresh {@link TongYiChatOptions} with default values
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link TongYiChatOptions}.
     *
     * <p>The builder writes directly into a single options instance and {@link #build()}
     * returns that same instance; no copy is made at any point.
     */
    public static class Builder {

        /** The instance being configured; also the object returned by {@link #build()}. */
        protected TongYiChatOptions options;

        /**
         * Creates a builder backed by a new options instance with default values.
         */
        public Builder() {
            this.options = new TongYiChatOptions();
        }

        /**
         * Creates a builder that mutates the given instance in place.
         *
         * @param options the options instance to configure; it is not copied
         */
        public Builder(TongYiChatOptions options) {
            this.options = options;
        }

        /**
         * @param model the model name
         * @return this builder
         */
        public Builder withModel(String model) {
            this.options.model = model;
            return this;
        }

        /**
         * @param maxTokens the maximum number of tokens to generate
         * @return this builder
         */
        public Builder withMaxTokens(Integer maxTokens) {
            this.options.maxTokens = maxTokens;
            return this;
        }

        /**
         * @param rf the result format
         * @return this builder
         */
        public Builder withResultFormat(String rf) {
            this.options.resultFormat = rf;
            return this;
        }

        /**
         * @param enableSearch the Internet search flag
         * @return this builder
         */
        public Builder withEnableSearch(Boolean enableSearch) {
            this.options.enableSearch = enableSearch;
            return this;
        }

        /**
         * @param functionCallbacks the function callback list; the reference is stored as-is
         * @return this builder
         */
        public Builder withFunctionCallbacks(List functionCallbacks) {
            this.options.functionCallbacks = functionCallbacks;
            return this;
        }

        /**
         * Replaces the set of enabled function names.
         *
         * @param functionNames the function names; must not be {@code null}
         * @return this builder
         * @throws IllegalArgumentException if {@code functionNames} is {@code null}
         */
        public Builder withFunctions(Set functionNames) {
            Assert.notNull(functionNames, "Function names must not be null");
            this.options.functions = functionNames;
            return this;
        }

        /**
         * Adds a single function name to the set of enabled functions.
         *
         * @param functionName the function name; must contain text
         * @return this builder
         * @throws IllegalArgumentException if {@code functionName} is null, empty or blank
         */
        public Builder withFunction(String functionName) {
            Assert.hasText(functionName, "Function name must not be empty");
            // The wrapped options may have had their set cleared via setFunctions(null).
            if (this.options.functions == null) {
                this.options.functions = new HashSet<>();
            }
            this.options.functions.add(functionName);
            return this;
        }

        /**
         * @param seed the random seed
         * @return this builder
         */
        public Builder withSeed(Integer seed) {
            this.options.seed = seed;
            return this;
        }

        /**
         * @param stop the stop list; the reference is stored as-is
         * @return this builder
         */
        public Builder withStop(List stop) {
            this.options.stop = stop;
            return this;
        }

        /**
         * @param temperature the temperature; {@code null} clears the option
         * @return this builder
         */
        public Builder withTemperature(Double temperature) {
            this.options.temperature = temperature;
            return this;
        }

        /**
         * @param topP the top-p threshold; {@code null} clears the option
         * @return this builder
         */
        public Builder withTopP(Double topP) {
            this.options.topP = topP;
            return this;
        }

        /**
         * @param topK the top-k candidate set size; {@code null} clears the option
         * @return this builder
         */
        public Builder withTopK(Integer topK) {
            this.options.topK = topK;
            return this;
        }

        /**
         * @param repetitionPenalty the repetition penalty; {@code null} clears the option
         * @return this builder
         */
        public Builder withRepetitionPenalty(Double repetitionPenalty) {
            this.options.repetitionPenalty = repetitionPenalty;
            return this;
        }

        /**
         * @return the options instance configured by this builder (the same object on
         *         every call)
         */
        public TongYiChatOptions build() {
            return this.options;
        }
    }
}