ForkJoin
ForkJoin
图片来源:狂神说Java
ForkJoin 任务窃取
每个工作线程维护的都是一个双端队列(Deque)
B的任务执行完了,就从A的双端队列的另一端窃取任务,帮助A执行
图片来源:狂神说Java
package com.sw.forkjoin;
import java.util.concurrent.RecursiveTask;
/**
 * Fork/join task that sums every integer in the closed range {@code [start, end]}.
 *
 * <p>How to use it:
 * <ol>
 *   <li>Create a {@code ForkJoinPool}.</li>
 *   <li>Hand the task to the pool, e.g. {@code forkJoinPool.submit(task)} or
 *       {@code forkJoinPool.invoke(task)}.</li>
 *   <li>The task class must be a {@code ForkJoinTask}; this one extends
 *       {@link RecursiveTask}, the subclass that returns a result.</li>
 * </ol>
 *
 * <p>The sum is computed with ordinary {@code long} arithmetic, so it wraps around silently
 * if the mathematical result does not fit in a {@code long}.
 *
 * @Author suaxi
 * @Date 2021/2/16 17:24
 */
public class ForkJoinTest extends RecursiveTask<Long> {

    /** Ranges whose {@code end - start} is below this value are summed sequentially. */
    private static final long THRESHOLD = 1000L;

    /** First value of the range (inclusive). */
    private final long start;

    /** Last value of the range (inclusive). */
    private final long end;

    /**
     * Creates a task that sums {@code start..end}, both bounds inclusive.
     *
     * @param start first value to add; must not be {@code null}
     * @param end   last value to add; must not be {@code null}
     * @throws NullPointerException if either bound is {@code null}
     */
    public ForkJoinTest(Long start, Long end) {
        if (start == null || end == null) {
            throw new NullPointerException("start and end must not be null");
        }
        // Unbox once here so the hot loop works on primitives only.
        this.start = start;
        this.end = end;
    }

    /**
     * Sums the range, splitting it in half recursively while it is at least
     * {@code THRESHOLD} wide.
     *
     * @return the sum of all values in {@code [start, end]}, or {@code 0} when {@code end < start}
     */
    @Override
    protected Long compute() {
        if (end < start) {
            return 0L; // empty range
        }

        long span = end - start;
        // With end >= start, a negative span means the subtraction overflowed, i.e. the range
        // is enormous and must be split rather than summed sequentially.
        if (span >= 0 && span < THRESHOLD) {
            return sumSequentially();
        }

        // Floor of (start + end) / 2 computed without the intermediate sum, which can overflow.
        long middle = (start & end) + ((start ^ end) >> 1);

        ForkJoinTest left = new ForkJoinTest(start, middle);
        ForkJoinTest right = new ForkJoinTest(middle + 1, end);

        // Queue only one half for other workers and compute the other half in this thread,
        // so the current worker stays busy instead of blocking on two joins.
        left.fork();
        long rightSum = right.compute();
        return left.join() + rightSum;
    }

    /** Adds up {@code start..end} in a plain loop; requires {@code start <= end}. */
    private long sumSequentially() {
        // Seed with the last element so the loop bound can be exclusive; an inclusive
        // "i <= end" would never terminate when end == Long.MAX_VALUE.
        long sum = end;
        for (long i = start; i < end; i++) {
            sum += i;
        }
        return sum;
    }
}
测试类:
package com.sw.forkjoin;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.stream.LongStream;
/**
 * Times three ways of summing the integers up to {@code UPPER_BOUND}: a plain loop,
 * a {@code ForkJoinPool} running {@code ForkJoinTest}, and a parallel {@code LongStream}.
 * Each method prints the sum and the elapsed wall-clock time in milliseconds.
 *
 * @Author suaxi
 * @Date 2021/2/16 17:37
 */
public class Test {

    /** Inclusive upper bound of the range summed by every variant (one billion). */
    private static final long UPPER_BOUND = 10_0000_0000L;

    /**
     * Runs the selected variant. Swap the commented-out calls to compare the others.
     *
     * @param args unused
     * @throws ExecutionException   declared for the fork/join variant
     * @throws InterruptedException declared for the fork/join variant
     */
    public static void main(String[] args) throws ExecutionException, InterruptedException {
        // Elapsed times (ms) noted by the original author; they vary by machine.
        //test01(); // elapsed = 7166
        //test02(); // elapsed = 5244
        test03(); // elapsed = 312
    }

    /** Plain sequential loop over {@code 1..UPPER_BOUND}. */
    public static void test01() {
        // Primitive long on purpose: a boxed Long accumulator allocates on every iteration.
        long sum = 0L;
        long start = System.currentTimeMillis();
        // Inclusive bound, so the result matches test02() and test03().
        for (long i = 1L; i <= UPPER_BOUND; i++) {
            sum += i;
        }
        long end = System.currentTimeMillis();
        System.out.println("sum="+sum+" 计算时间="+ (end - start));
    }

    /**
     * Fork/join variant: submits a {@code ForkJoinTest} to a dedicated pool.
     *
     * @throws ExecutionException   if the task fails with an exception
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public static void test02() throws ExecutionException, InterruptedException {
        long start = System.currentTimeMillis();
        // 1. Execute through a ForkJoinPool.
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        try {
            ForkJoinTask<Long> task = new ForkJoinTest(0L, UPPER_BOUND);
            // 2. Submit the task and wait for its result.
            ForkJoinTask<Long> submit = forkJoinPool.submit(task);
            long sum = submit.get();
            long end = System.currentTimeMillis();
            System.out.println("sum="+sum+" 计算时间="+ (end - start));
        } finally {
            // Release the pool's worker threads even if the task failed.
            forkJoinPool.shutdown();
        }
    }

    /** Parallel-stream variant over {@code 0..UPPER_BOUND}. */
    public static void test03() {
        long start = System.currentTimeMillis();
        // parallel() asks the stream to split the range across threads.
        long sum = LongStream.rangeClosed(0L, UPPER_BOUND).parallel().reduce(0, Long::sum);
        long end = System.currentTimeMillis();
        System.out.println("sum="+sum+" 计算时间="+ (end - start));
    }
}
评论 (0)