死磕 java執行緒系列之ForkJoinPool深入解析

  • 2019 年 11 月 9 日
  • 筆記

forkjoinpool

(手機橫屏看源碼更方便)


註:java源碼分析部分如無特殊說明均基於 java8 版本。

註:本文基於ForkJoinPool分治執行緒池類。

簡介

隨著在硬體上多核處理器的發展和廣泛使用,並發編程成為程式設計師必須掌握的一門技術,在面試中也經常考查面試者並發相關的知識。

今天,我們就來看一道面試題:

如何充分利用多核CPU,計算很大數組中所有整數的和?

剖析

  • 單執行緒相加?

我們最容易想到就是單執行緒相加,一個for循環搞定。

  • 執行緒池相加?

如果進一步優化,我們會自然而然地想到使用執行緒池來分段相加,最後再把每個段的結果相加。

  • 其它?

Yes,就是我們今天的主角——ForkJoinPool,但是它要怎麼實現呢?似乎沒怎麼用過哈^^

三種實現

OK,剖析完了,我們直接來看三種實現,不墨跡,直接上菜。

/**   * 計算1億個整數的和   */  public class ForkJoinPoolTest01 {      public static void main(String[] args) throws ExecutionException, InterruptedException {          // 構造數據          int length = 100000000;          long[] arr = new long[length];          for (int i = 0; i < length; i++) {              arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE);          }          // 單執行緒          singleThreadSum(arr);          // ThreadPoolExecutor執行緒池          multiThreadSum(arr);          // ForkJoinPool執行緒池          forkJoinSum(arr);        }        private static void singleThreadSum(long[] arr) {          long start = System.currentTimeMillis();            long sum = 0;          for (int i = 0; i < arr.length; i++) {              // 模擬耗時,本文由公從號「彤哥讀源碼」原創              sum += (arr[i]/3*3/3*3/3*3/3*3/3*3);          }            System.out.println("sum: " + sum);          System.out.println("single thread elapse: " + (System.currentTimeMillis() - start));        }        private static void multiThreadSum(long[] arr) throws ExecutionException, InterruptedException {          long start = System.currentTimeMillis();            int count = 8;          ExecutorService threadPool = Executors.newFixedThreadPool(count);          List<Future<Long>> list = new ArrayList<>();          for (int i = 0; i < count; i++) {              int num = i;              // 分段提交任務              Future<Long> future = threadPool.submit(() -> {                  long sum = 0;                  for (int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) {                      try {                          // 模擬耗時                          sum += (arr[j]/3*3/3*3/3*3/3*3/3*3);                      } catch (Exception e) {                          e.printStackTrace();                      }                  }                  return sum;              });              list.add(future);          }            // 每個段結果相加          long sum = 0;          for (Future<Long> future : list) {              sum += future.get();          }            System.out.println("sum: " + sum);          System.out.println("multi thread elapse: " + (System.currentTimeMillis() - start));      }        private static void forkJoinSum(long[] arr) throws ExecutionException, InterruptedException {          long start = System.currentTimeMillis();            ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();          // 提交任務          ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(new SumTask(arr, 0, arr.length));          // 獲取結果          Long sum = forkJoinTask.get();            forkJoinPool.shutdown();            System.out.println("sum: " + sum);          System.out.println("fork join elapse: " + (System.currentTimeMillis() - start));      }        private static class SumTask extends RecursiveTask<Long> {          private long[] arr;          private int from;          private int to;            public SumTask(long[] arr, int from, int to) {              this.arr = arr;              this.from = from;              this.to = to;          }            @Override          protected Long compute() {              // 小於1000的時候直接相加,可靈活調整              if (to - from <= 1000) {                  long sum = 0;                  for (int i = from; i < to; i++) {                      // 模擬耗時                      sum += (arr[i]/3*3/3*3/3*3/3*3/3*3);                  }                  return sum;              }                // 分成兩段任務,本文由公從號「彤哥讀源碼」原創              int middle = (from + to) / 2;              SumTask left = new SumTask(arr, from, middle);              SumTask right = new SumTask(arr, middle, to);                // 提交左邊的任務              left.fork();              // 右邊的任務直接利用當前執行緒計算,節約開銷              Long rightResult = right.compute();              // 等待左邊計算完畢              Long leftResult = left.join();              // 返回結果              return leftResult + rightResult;          }      }  }

彤哥偷偷地告訴你,實際上計算1億個整數相加,單執行緒是最快的,我的電腦大概是100ms左右,使用執行緒池反而會變慢。

所以,為了演示ForkJoinPool的牛逼之處,我把每個數都/3*3/3*3/3*3/3*3/3*3了一頓操作,用來模擬計算耗時。

來看結果:

sum: 107352457433800662  single thread elapse: 789  sum: 107352457433800662  multi thread elapse: 228  sum: 107352457433800662  fork join elapse: 189

可以看到,ForkJoinPool相對普通執行緒池還是有很大提升的。

問題:普通執行緒池能否實現ForkJoinPool這種計算方式呢,即大任務拆中任務,中任務拆小任務,最後再匯總?

forkjoinpool

你可以試試看(-᷅_-᷄)

OK,下面我們正式進入ForkJoinPool的解析。

分治法

  • 基本思想

把一個規模大的問題劃分為規模較小的子問題,然後分而治之,最後合併子問題的解得到原問題的解。

  • 步驟

(1)分割原問題:

(2)求解子問題:

(3)合併子問題的解為原問題的解。

在分治法中,子問題一般是相互獨立的,因此,經常通過遞歸調用演算法來求解子問題。

  • 典型應用場景

(1)二分搜索

(2)大整數乘法

(3)Strassen矩陣乘法

(4)棋盤覆蓋

(5)歸併排序

(6)快速排序

(7)線性時間選擇

(8)漢諾塔

ForkJoinPool繼承體系

ForkJoinPool是 java 7 中新增的執行緒池類,它的繼承體系如下:

forkjoinpool

ForkJoinPool和ThreadPoolExecutor都是繼承自AbstractExecutorService抽象類,所以它和ThreadPoolExecutor的使用幾乎沒有多少區別,除了任務變成了ForkJoinTask以外。

這裡又運用到了一種很重要的設計原則——開閉原則——對修改關閉,對擴展開放。

可見整個執行緒池體系一開始的介面設計就很好,新增一個執行緒池類,不會對原有的程式碼造成干擾,還能利用原有的特性。

ForkJoinTask

兩個主要方法

  • fork()

fork()方法類似於執行緒的Thread.start()方法,但是它不是真的啟動一個執行緒,而是將任務放入到工作隊列中。

  • join()

join()方法類似於執行緒的Thread.join()方法,但是它不是簡單地阻塞執行緒,而是利用工作執行緒運行其它任務。當一個工作執行緒中調用了join()方法,它將處理其它任務,直到注意到目標子任務已經完成了。

三個子類

  • RecursiveAction

無返回值任務。

  • RecursiveTask

有返回值任務。

  • CountedCompleter

無返回值任務,完成任務後可以觸發回調。

ForkJoinPool內部原理

ForkJoinPool內部使用的是「工作竊取」演算法實現的。

forkjoinpool

(1)每個工作執行緒都有自己的工作隊列WorkQueue;

(2)這是一個雙端隊列,它是執行緒私有的;

(3)ForkJoinTask中fork的子任務,將放入運行該任務的工作執行緒的隊頭,工作執行緒將以LIFO的順序來處理工作隊列中的任務;

(4)為了最大化地利用CPU,空閑的執行緒將從其它執行緒的隊列中「竊取」任務來執行;

(5)從工作隊列的尾部竊取任務,以減少競爭;

(6)雙端隊列的操作:push()/pop()僅在其所有者工作執行緒中調用,poll()是由其它執行緒竊取任務時調用的;

(7)當只剩下最後一個任務時,還是會存在競爭,是通過CAS來實現的;

forkjoinpool

ForkJoinPool最佳實踐

(1)最適合的是計算密集型任務,本文由公從號「彤哥讀源碼」原創;

(2)在需要阻塞工作執行緒時,可以使用ManagedBlocker;

(3)不應該在RecursiveTask的內部使用ForkJoinPool.invoke()/invokeAll();

總結

(1)ForkJoinPool特別適合於「分而治之」演算法的實現;

(2)ForkJoinPool和ThreadPoolExecutor是互補的,不是誰替代誰的關係,二者適用的場景不同;

(3)ForkJoinTask有兩個核心方法——fork()和join(),有三個重要子類——RecursiveAction、RecursiveTask和CountedCompleter;

(4)ForkjoinPool內部基於「工作竊取」演算法實現;

(5)每個執行緒有自己的工作隊列,它是一個雙端隊列,自己從隊列頭存取任務,其它執行緒從尾部竊取任務;

(6)ForkJoinPool最適合於計算密集型任務,但也可以使用ManagedBlocker以便用於阻塞型任務;

(7)RecursiveTask內部可以少調用一次fork(),利用當前執行緒處理,這是一種技巧;

彩蛋

ManagedBlocker怎麼使用?

答:ManagedBlocker相當於明確告訴ForkJoinPool框架要阻塞了,ForkJoinPool就會啟另一個執行緒來運行任務,以最大化地利用CPU。

請看下面的例子,自己琢磨哈^^。

/**   * 斐波那契數列   * 一個數是它前面兩個數之和   * 1,1,2,3,5,8,13,21   */  public class Fibonacci {        public static void main(String[] args) {          long time = System.currentTimeMillis();          Fibonacci fib = new Fibonacci();          int result = fib.f(1_000).bitCount();          time = System.currentTimeMillis() - time;          System.out.println("result,本文由公從號「彤哥讀源碼」原創 = " + result);          System.out.println("test1_000() time = " + time);      }        public BigInteger f(int n) {          Map<Integer, BigInteger> cache = new ConcurrentHashMap<>();          cache.put(0, BigInteger.ZERO);          cache.put(1, BigInteger.ONE);          return f(n, cache);      }        private final BigInteger RESERVED = BigInteger.valueOf(-1000);        public BigInteger f(int n, Map<Integer, BigInteger> cache) {          BigInteger result = cache.putIfAbsent(n, RESERVED);          if (result == null) {                int half = (n + 1) / 2;                RecursiveTask<BigInteger> f0_task = new RecursiveTask<BigInteger>() {                  @Override                  protected BigInteger compute() {                      return f(half - 1, cache);                  }              };              f0_task.fork();                BigInteger f1 = f(half, cache);              BigInteger f0 = f0_task.join();                long time = n > 10_000 ? System.currentTimeMillis() : 0;              try {                    if (n % 2 == 1) {                      result = f0.multiply(f0).add(f1.multiply(f1));                  } else {                      result = f0.shiftLeft(1).add(f1).multiply(f1);                  }                  synchronized (RESERVED) {                      cache.put(n, result);                      RESERVED.notifyAll();                  }              } finally {                  time = n > 10_000 ? System.currentTimeMillis() - time : 0;                  if (time > 50)                      System.out.printf("f(%d) took %d%n", n, time);              }          } else if (result == RESERVED) {              try {                  ReservedFibonacciBlocker blocker = new ReservedFibonacciBlocker(n, cache);                  ForkJoinPool.managedBlock(blocker);                  result = blocker.result;              } catch (InterruptedException e) {                  throw new CancellationException("interrupted");              }            }          return result;          // return f(n - 1).add(f(n - 2));      }        private class ReservedFibonacciBlocker implements ForkJoinPool.ManagedBlocker {          private BigInteger result;          private final int n;          private final Map<Integer, BigInteger> cache;            public ReservedFibonacciBlocker(int n, Map<Integer, BigInteger> cache) {              this.n = n;              this.cache = cache;          }            @Override          public boolean block() throws InterruptedException {              synchronized (RESERVED) {                  while (!isReleasable()) {                      RESERVED.wait();                  }              }              return true;          }            @Override          public boolean isReleasable() {              return (result = cache.get(n)) != RESERVED;          }      }  }

歡迎關注我的公眾號「彤哥讀源碼」,查看更多源碼系列文章, 與彤哥一起暢遊源碼的海洋。

qrcode