在使用多线程执行任务时,有时需要取消执行的需求,即取消某个线程正在执行的任务,一般会在线程中加入一个取消标志,在任务执行期间不断检查标志,当检查到取消时抛出特定异常中断执行并进行取消后处理。这里说一下在springboot中,使用Async注解时,取消执行的实现方式。

自定义线程

要实现任务取消,首先要在线程中加入任务取消标志任务执行时不断检查该标志,当标志被置为true时,抛出特定异常结束执行,并在任务处理的最后捕获该异常进行善后处理,因此需要自定义一下Thread

package ink.labrador.taskcancel.wrapper;

public class TaskThreadWrapper extends Thread {
    private volatile boolean taskCanceled = false; // 取消标志

    public TaskThreadWrapper(ThreadGroup group, Runnable r, String threadName, int i) {
        super(group, r, threadName, i);
    }

    public TaskThreadWrapper(Runnable runnable) {
        super(runnable);
    }

    public TaskThreadWrapper() {
        super();
    }

    /**
    * 取消任务即为把标志置为true
    **/
    public void cancelTask() {
        this.taskCanceled = true;
    }

    public boolean isTaskCanceled() {
        return taskCanceled;
    }

    public void clearState() {
        this.taskCanceled = false;
    }
}

自定义线程工厂

springboot的线程池在创建新的线程时,会使用默认的线程工厂创建,创建的线程为原生的Thread,因此需要自定义一下线程工厂,以便使新创建的线程实例是我们自定义的线程:

package ink.labrador.taskcancel.wrapper;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

public class TaskThreadFactoryWrapper implements ThreadFactory {
    private final Logger logger = LoggerFactory.getLogger(ThreadFactory.class);

    private final ThreadGroup group;
    private final AtomicInteger threadNumber = new AtomicInteger(1);
    private final String namePrefix;

    public TaskThreadFactoryWrapper() {
        SecurityManager s = System.getSecurityManager();
        group = (s != null) ? s.getThreadGroup() : Thread.currentThread().getThreadGroup();
        namePrefix = "Task-Thread-";
    }

    // 创建新的线程。线程池在创建线程时就是通过调用该方法创建的
    @Override
    public Thread newThread(Runnable r) {
        String threadName = namePrefix + threadNumber.getAndIncrement();
        Thread t = new TaskThreadWrapper(group, r, threadName, 0); // 使用自定义的线程
        logger.info(String.format("task thread %s was created ...", threadName));
        if (t.isDaemon()) {
            t.setDaemon(false);
        }
        if (t.getPriority() != Thread.NORM_PRIORITY) {
            t.setPriority(Thread.NORM_PRIORITY);
        }
        return t;
    }
}

配置线程池

@Async注解默认使用的线程池是SimpleAsyncTaskExecutor,这个线程池不复用线程,每次都会创建新的线程,一般不建议使用,这里自定义一下线程池的配置,使用ThreadPoolTaskExecutor线程池,并应用自定义的线程工厂:

package ink.labrador.taskcancel.configuration;

import ink.labrador.taskcancel.wrapper.TaskThreadFactoryWrapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.concurrent.*;

@Configuration
public class TaskExecutorConfiguration  {

    /**
     * 自定义线程池
     * @return 线程池
     */
    @Bean
    public Executor geTaskExecutor() {
        int poolSize = 5;

        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        executor.setCorePoolSize(poolSize); // 核心线程数量
        executor.setMaxPoolSize(poolSize); // 最大线程数量
        executor.setQueueCapacity(poolSize * 2); // 最大任务排队数量,当排队的任务数量超出时,会产生异常
        executor.setThreadNamePrefix("Task-Worker - ");
        executor.setAwaitTerminationSeconds(15);
        executor.setThreadFactory(new TaskThreadFactoryWrapper()); // 应用自定义的线程工厂,创建的线程类型就都是TaskThreadWrapper了
        executor.initialize();
        return executor;
    }
}

定义任务取消异常

自定义一个异常,任务取消时抛出该异常:

package ink.labrador.taskcancel.exception;

import ink.labrador.taskcancel.wrapper.TaskThreadWrapper;

public class TaskCanceledException extends Exception{

    public TaskCanceledException() {
        super("task was canceled");
    }
}

定义任务执行上下文

任务执行上下文主要提供两个作用:

  • 缓存任务与线程的对应信息

    这样当取消任务时,可以根据任务的唯一标志(如id)找到对应的线程,进行取消动作。

  • 全局提供任务信息上下文

    通过ThreadLocal保存任务的关键信息(如id、名称、任务配置、任务中间状态等等),这样在整个执行过程中,可以随时拿到任务信息,不用在函数之间显示传递。

此外,还可对正在执行的任务数量进行简单计数。假设任务信息定义如下:

package ink.labrador.taskcancel.executor;
import lombok.*;

@Data
@NoArgsConstructor
public class TaskMeta {
    private Integer id;
    private String name;
    private Integer status;
}

定义任务执行上下文如下:

package ink.labrador.taskcancel.executor;

import ink.labrador.taskcancel.exception.TaskCanceledException;
import ink.labrador.taskcancel.wrapper.TaskThreadWrapper;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * 线程执行任务的上下文
 */
public class TaskExecutorContext implements AutoCloseable {
    static final ThreadLocal<TaskMeta> ctx = new ThreadLocal<>(); // 全局任务信息
    static final AtomicInteger runningTaskNumbers = new AtomicInteger(0); // 正在执行的任务数量计数
    // 任务id - 线程对应关系
    static final ConcurrentHashMap<Integer, TaskThreadWrapper> taskAndThreadHolder = new ConcurrentHashMap<>();

    public TaskExecutorContext(TaskMeta taskMeta) {
        ctx.set(taskMeta);
        taskAndThreadHolder.put(taskMeta.getId(), thread);
        runningTaskNumbers.incrementAndGet();
    }

    public static TaskMeta currentTaskMeta() {
        return ctx.get();
    }

    public static int runningTaskNumbers() {
        return runningTaskNumbers.get();
    }


    // 检查当前线程对应执行的任务的取消状态,也即检查任务取消标志
    public static void checkAndThrowExceptionIfTaskCanceled() throws TaskCanceledException {
        TaskThreadWrapper thread = (TaskThreadWrapper) Thread.currentThread();
        if (thread .isTaskCanceled()) {
            // 抛出任务取消异常
            throw new TaskCanceledException();
        }
    }

   // 获取可用的线程数量
    public static int availableWorkersNumber() {
        int poolSize = 5;
        return poolSize - runningTaskNumbers();
    }
 
    // 取消任务,即把对应任务的线程中的任务取消标志置为1
    public static void cancelTask(int taskId) {
        try {
            // 获取到执行任务的线程
            TaskThreadWrapper threadWrapper = taskAndThreadHolder.get(taskId);
            if (threadWrapper != null) {
                threadWrapper.cancelTask(); // 取消任务
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static ConcurrentHashMap<Integer, TaskThreadWrapper> getTaskAndThreadHolder() {
        return taskAndThreadHolder;
    }

    // 上下文关闭时(也即任务结束后)的善后处理
    @Override
    public void close() {
        try {
            TaskThreadWrapper thread = (TaskThreadWrapper) Thread.currentThread();
            runningTaskNumbers.decrementAndGet();
            taskAndThreadHolder.remove(currentTaskMeta().getId());
            thread.clearState();
            ctx.remove();
        } catch (Exception ignored) {}
    }
}

使用

在执行任务的入口函数上应用@Async注解,并用任务上下文包裹,在最外层捕获任务取消异常,如下:

package ink.labrador.taskcancel.executor;

import ink.labrador.taskcancel.service.TaskService;
import ink.labrador.taskcancel.exception.TaskCanceledException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Component
public class TaskExecutorService {
    private final Logger logger = LoggerFactory.getLogger(TaskExecutorService .class);
    @Autowired private TaskService taskService;

    // 任务执行入口
    @Async
    public void run(int taskId) {
        // 查询数据库获取到任务详细信息 ...
        TaskMeta taskMeta = taskService.getById(taskId);
        try(TaskExecutorContext taskExecutorContext = new TaskExecutorContext(taskMeta)) {
            try {
                processTask();
            } catch (TaskCanceledException taskCanceledException) {
                // 捕获任务取消异常
                logger.warn(String.format("任务被取消: %s ...", taskMeta.getName));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public void processTask() throws TaskCanceledException {
        TaskMeta taskMeta = TaskExecutorContext.currentTaskMeta(); // 可通过上下文直接获取到任务信息
        logger.info(String.format("开始执行任务: %s ...", taskMeta.getName));
        while(...) {
            // 任务处理过程中不断检查任务是否被取消
            TaskExecutorContext.checkAndThrowExceptionIfTaskCanceled();
        }
    }
}

当需要取消任务时,可以在任何地方执行取消操作,比如提供任务取消的rpc接口等,只需要调用取消函数即可:

TaskExecutorContext.cancelTask(taskId)

这样,当任务处理走到检查任务取消标志的地方时就会直接抛出异常,中断执行。