package com.xmzs.midjourney.controller;
|
|
import cn.hutool.core.text.CharSequenceUtil;
|
import cn.hutool.core.util.RandomUtil;
|
import com.xmzs.midjourney.Constants;
|
import com.xmzs.midjourney.ProxyProperties;
|
import com.xmzs.midjourney.ReturnCode;
|
import com.xmzs.midjourney.dto.BaseSubmitDTO;
|
import com.xmzs.midjourney.dto.SubmitBlendDTO;
|
import com.xmzs.midjourney.dto.SubmitChangeDTO;
|
import com.xmzs.midjourney.dto.SubmitDescribeDTO;
|
import com.xmzs.midjourney.dto.SubmitImagineDTO;
|
import com.xmzs.midjourney.dto.SubmitSimpleChangeDTO;
|
import com.xmzs.midjourney.enums.TaskAction;
|
import com.xmzs.midjourney.enums.TaskStatus;
|
import com.xmzs.midjourney.enums.TranslateWay;
|
import com.xmzs.midjourney.exception.BannedPromptException;
|
import com.xmzs.midjourney.result.SubmitResultVO;
|
import com.xmzs.midjourney.service.TaskService;
|
import com.xmzs.midjourney.service.TaskStoreService;
|
import com.xmzs.midjourney.service.TranslateService;
|
import com.xmzs.midjourney.support.Task;
|
import com.xmzs.midjourney.support.TaskCondition;
|
import com.xmzs.midjourney.util.BannedPromptUtils;
|
import com.xmzs.midjourney.util.ConvertUtils;
|
import com.xmzs.midjourney.util.MimeTypeUtils;
|
import com.xmzs.midjourney.util.SnowFlake;
|
import com.xmzs.midjourney.util.TaskChangeParams;
|
import eu.maxschuster.dataurl.DataUrl;
|
import eu.maxschuster.dataurl.DataUrlSerializer;
|
import eu.maxschuster.dataurl.IDataUrlSerializer;
|
import io.swagger.annotations.Api;
|
import io.swagger.annotations.ApiOperation;
|
import lombok.RequiredArgsConstructor;
|
import org.springframework.web.bind.annotation.PostMapping;
|
import org.springframework.web.bind.annotation.RequestBody;
|
import org.springframework.web.bind.annotation.RequestMapping;
|
import org.springframework.web.bind.annotation.RestController;
|
|
import java.net.MalformedURLException;
|
import java.util.ArrayList;
|
import java.util.List;
|
import java.util.Optional;
|
import java.util.Set;
|
import java.util.regex.Matcher;
|
import java.util.regex.Pattern;
|
|
@Api(tags = "任务提交")
|
@RestController
|
@RequestMapping("/mj/submit")
|
@RequiredArgsConstructor
|
public class SubmitController {
|
private final TranslateService translateService;
|
private final TaskStoreService taskStoreService;
|
private final ProxyProperties properties;
|
private final TaskService taskService;
|
|
@ApiOperation(value = "提交Imagine任务")
|
@PostMapping("/imagine")
|
public SubmitResultVO imagine(@RequestBody SubmitImagineDTO imagineDTO) {
|
String prompt = imagineDTO.getPrompt();
|
if (CharSequenceUtil.isBlank(prompt)) {
|
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "prompt不能为空");
|
}
|
prompt = prompt.trim();
|
Task task = newTask(imagineDTO);
|
task.setAction(TaskAction.IMAGINE);
|
task.setPrompt(prompt);
|
String promptEn = translatePrompt(prompt);
|
try {
|
BannedPromptUtils.checkBanned(promptEn);
|
} catch (BannedPromptException e) {
|
return SubmitResultVO.fail(ReturnCode.BANNED_PROMPT, "可能包含敏感词")
|
.setProperty("promptEn", promptEn).setProperty("bannedWord", e.getMessage());
|
}
|
List<String> base64Array = Optional.ofNullable(imagineDTO.getBase64Array()).orElse(new ArrayList<>());
|
if (CharSequenceUtil.isNotBlank(imagineDTO.getBase64())) {
|
base64Array.add(imagineDTO.getBase64());
|
}
|
List<DataUrl> dataUrls;
|
try {
|
dataUrls = ConvertUtils.convertBase64Array(base64Array);
|
} catch (MalformedURLException e) {
|
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
|
}
|
task.setPromptEn(promptEn);
|
task.setDescription("/imagine " + prompt);
|
return this.taskService.submitImagine(task, dataUrls);
|
}
|
|
@ApiOperation(value = "绘图变化-simple")
|
@PostMapping("/simple-change")
|
public SubmitResultVO simpleChange(@RequestBody SubmitSimpleChangeDTO simpleChangeDTO) {
|
TaskChangeParams changeParams = ConvertUtils.convertChangeParams(simpleChangeDTO.getContent());
|
if (changeParams == null) {
|
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "content参数错误");
|
}
|
SubmitChangeDTO changeDTO = new SubmitChangeDTO();
|
changeDTO.setAction(changeParams.getAction());
|
changeDTO.setTaskId(changeParams.getId());
|
changeDTO.setIndex(changeParams.getIndex());
|
changeDTO.setState(simpleChangeDTO.getState());
|
changeDTO.setNotifyHook(simpleChangeDTO.getNotifyHook());
|
return change(changeDTO);
|
}
|
|
@ApiOperation(value = "绘图变化")
|
@PostMapping("/change")
|
public SubmitResultVO change(@RequestBody SubmitChangeDTO changeDTO) {
|
if (CharSequenceUtil.isBlank(changeDTO.getTaskId())) {
|
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "taskId不能为空");
|
}
|
if (!Set.of(TaskAction.UPSCALE, TaskAction.VARIATION, TaskAction.REROLL).contains(changeDTO.getAction())) {
|
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "action参数错误");
|
}
|
String description = "/up " + changeDTO.getTaskId();
|
if (TaskAction.REROLL.equals(changeDTO.getAction())) {
|
description += " R";
|
} else {
|
description += " " + changeDTO.getAction().name().charAt(0) + changeDTO.getIndex();
|
}
|
if (TaskAction.UPSCALE.equals(changeDTO.getAction())) {
|
TaskCondition condition = new TaskCondition().setDescription(description);
|
Task existTask = this.taskStoreService.findOne(condition);
|
if (existTask != null) {
|
return SubmitResultVO.of(ReturnCode.EXISTED, "任务已存在", existTask.getId())
|
.setProperty("status", existTask.getStatus())
|
.setProperty("imageUrl", existTask.getImageUrl());
|
}
|
}
|
Task targetTask = this.taskStoreService.get(changeDTO.getTaskId());
|
if (targetTask == null) {
|
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "关联任务不存在或已失效");
|
}
|
if (!TaskStatus.SUCCESS.equals(targetTask.getStatus())) {
|
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务状态错误");
|
}
|
if (!Set.of(TaskAction.IMAGINE, TaskAction.VARIATION, TaskAction.REROLL, TaskAction.BLEND).contains(targetTask.getAction())) {
|
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务不允许执行变化");
|
}
|
Task task = newTask(changeDTO);
|
task.setAction(changeDTO.getAction());
|
task.setPrompt(targetTask.getPrompt());
|
task.setPromptEn(targetTask.getPromptEn());
|
task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, targetTask.getProperty(Constants.TASK_PROPERTY_FINAL_PROMPT));
|
task.setProperty(Constants.TASK_PROPERTY_PROGRESS_MESSAGE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_MESSAGE_ID));
|
task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID));
|
task.setDescription(description);
|
int messageFlags = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_FLAGS);
|
String messageId = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_ID);
|
String messageHash = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_HASH);
|
if (TaskAction.UPSCALE.equals(changeDTO.getAction())) {
|
return this.taskService.submitUpscale(task, messageId, messageHash, changeDTO.getIndex(), messageFlags);
|
} else if (TaskAction.VARIATION.equals(changeDTO.getAction())) {
|
return this.taskService.submitVariation(task, messageId, messageHash, changeDTO.getIndex(), messageFlags);
|
} else {
|
return this.taskService.submitReroll(task, messageId, messageHash, messageFlags);
|
}
|
}
|
|
@ApiOperation(value = "提交Describe任务")
|
@PostMapping("/describe")
|
public SubmitResultVO describe(@RequestBody SubmitDescribeDTO describeDTO) {
|
if (CharSequenceUtil.isBlank(describeDTO.getBase64())) {
|
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64不能为空");
|
}
|
IDataUrlSerializer serializer = new DataUrlSerializer();
|
DataUrl dataUrl;
|
try {
|
dataUrl = serializer.unserialize(describeDTO.getBase64());
|
} catch (MalformedURLException e) {
|
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
|
}
|
Task task = newTask(describeDTO);
|
task.setAction(TaskAction.DESCRIBE);
|
String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType());
|
task.setDescription("/describe " + taskFileName);
|
return this.taskService.submitDescribe(task, dataUrl);
|
}
|
|
@ApiOperation(value = "提交Blend任务")
|
@PostMapping("/blend")
|
public SubmitResultVO blend(@RequestBody SubmitBlendDTO blendDTO) {
|
List<String> base64Array = blendDTO.getBase64Array();
|
if (base64Array == null || base64Array.size() < 2 || base64Array.size() > 5) {
|
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64List参数错误");
|
}
|
if (blendDTO.getDimensions() == null) {
|
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "dimensions参数错误");
|
}
|
IDataUrlSerializer serializer = new DataUrlSerializer();
|
List<DataUrl> dataUrlList = new ArrayList<>();
|
try {
|
for (String base64 : base64Array) {
|
DataUrl dataUrl = serializer.unserialize(base64);
|
dataUrlList.add(dataUrl);
|
}
|
} catch (MalformedURLException e) {
|
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
|
}
|
Task task = newTask(blendDTO);
|
task.setAction(TaskAction.BLEND);
|
task.setDescription("/blend " + task.getId() + " " + dataUrlList.size());
|
return this.taskService.submitBlend(task, dataUrlList, blendDTO.getDimensions());
|
}
|
|
private Task newTask(BaseSubmitDTO base) {
|
Task task = new Task();
|
task.setId(System.currentTimeMillis() + RandomUtil.randomNumbers(3));
|
task.setSubmitTime(System.currentTimeMillis());
|
task.setState(base.getState());
|
String notifyHook = CharSequenceUtil.isBlank(base.getNotifyHook()) ? this.properties.getNotifyHook() : base.getNotifyHook();
|
task.setProperty(Constants.TASK_PROPERTY_NOTIFY_HOOK, notifyHook);
|
task.setProperty(Constants.TASK_PROPERTY_NONCE, SnowFlake.INSTANCE.nextId());
|
return task;
|
}
|
|
private String translatePrompt(String prompt) {
|
if (TranslateWay.NULL.equals(this.properties.getTranslateWay()) || CharSequenceUtil.isBlank(prompt)) {
|
return prompt;
|
}
|
List<String> imageUrls = new ArrayList<>();
|
Matcher imageMatcher = Pattern.compile("https?://[a-z0-9-_:@&?=+,.!/~*'%$]+\\x20+", Pattern.CASE_INSENSITIVE).matcher(prompt);
|
while (imageMatcher.find()) {
|
imageUrls.add(imageMatcher.group(0));
|
}
|
String paramStr = "";
|
Matcher paramMatcher = Pattern.compile("\\x20+-{1,2}[a-z]+.*$", Pattern.CASE_INSENSITIVE).matcher(prompt);
|
if (paramMatcher.find()) {
|
paramStr = paramMatcher.group(0);
|
}
|
String imageStr = CharSequenceUtil.join("", imageUrls);
|
String text = prompt.substring(imageStr.length(), prompt.length() - paramStr.length());
|
if (CharSequenceUtil.isNotBlank(text)) {
|
text = this.translateService.translateToEnglish(text).trim();
|
}
|
return imageStr + text + paramStr;
|
}
|
|
}
|