package com.xmzs.midjourney.controller; import cn.hutool.core.comparator.CompareUtil; import com.xmzs.midjourney.dto.TaskConditionDTO; import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer; import com.xmzs.midjourney.service.TaskStoreService; import com.xmzs.midjourney.support.Task; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; import io.swagger.annotations.ApiParam; import lombok.RequiredArgsConstructor; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Objects; @Api(tags = "任务查询") @RestController @RequestMapping("/mj/task") @RequiredArgsConstructor public class TaskController { private final TaskStoreService taskStoreService; private final DiscordLoadBalancer discordLoadBalancer; @ApiOperation(value = "指定ID获取任务") @GetMapping("/{id}/fetch") public Task fetch(@ApiParam(value = "任务ID") @PathVariable String id) { return this.taskStoreService.get(id); } @ApiOperation(value = "查询任务队列") @GetMapping("/queue") public List queue() { return this.discordLoadBalancer.getQueueTaskIds().stream() .map(this.taskStoreService::get).filter(Objects::nonNull) .sorted(Comparator.comparing(Task::getSubmitTime)) .toList(); } @ApiOperation(value = "查询所有任务") @GetMapping("/list") public List list() { return this.taskStoreService.list().stream() .sorted((t1, t2) -> CompareUtil.compare(t2.getSubmitTime(), t1.getSubmitTime())) .toList(); } @ApiOperation(value = "根据ID列表查询任务") @PostMapping("/list-by-condition") public List listByIds(@RequestBody TaskConditionDTO conditionDTO) { if (conditionDTO.getIds() == null) { return Collections.emptyList(); } return conditionDTO.getIds().stream().map(this.taskStoreService::get).filter(Objects::nonNull).toList(); } }