A Look at Thread Variable Binding: TransmittableThreadLocal

  • December 19, 2019
  • Notes

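In the previous post we saw that InheritableThreadLocal cannot pass ThreadLocal values along when threads are reused. TransmittableThreadLocal (TTL), open-sourced by Alibaba, provides ThreadLocal value transmission for components that pool and reuse threads, such as thread pools, and so solves the problem of passing context during asynchronous execution. TransmittableThreadLocal has to be used together with TTL's TtlExecutors, TtlRunnable and TtlCallable; alternatively, a Java Agent can decorate thread pools non-invasively. TransmittableThreadLocal itself extends InheritableThreadLocal.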
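To make the limitation concrete, here is a minimal JDK-only sketch (no TTL involved; the class and method names are hypothetical). With a single-thread pool, the worker thread is created during the first submission and inherits the value 0 at that moment; every later task runs on the same reused thread and keeps seeing that stale copy, whatever the submitting thread sets afterwards.

```java
import java.util.Arrays;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class InheritableReuseDemo {
    // Plain JDK InheritableThreadLocal: a value is copied only when a thread is created.
    static final InheritableThreadLocal<Integer> CONTEXT = new InheritableThreadLocal<>();

    /** Returns the value each of three tasks observed inside the single pooled thread. */
    static int[] observedValues() {
        // One worker thread: it is created during the first submit, then reused.
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Callable<Integer> readContext = CONTEXT::get;
        try {
            int[] seen = new int[3];
            for (int i = 0; i < 3; i++) {
                CONTEXT.set(i); // the submitting thread changes its value every round
                seen[i] = pool.submit(readContext).get();
            }
            return seen;
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(observedValues())); // [0, 0, 0]
    }
}
```

The tasks never see 1 or 2, which is exactly the gap that TtlRunnable and TtlExecutors are meant to close.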

Example

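```java
// service is the single-thread pool declared in the next example:
// private ExecutorService service = Executors.newFixedThreadPool(1);
@Test
public void testTtlRunnableTransmittableThreadLocalByThreadPool() {
    TransmittableThreadLocal<Integer> threadLocal = new TransmittableThreadLocal<>();
    IntStream.range(0, 10).forEach(i -> {
        System.out.println(i);
        threadLocal.set(i);
        // wrap each task so the current value is captured at submission time
        service.execute(TtlRunnable.get(() ->
                System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get())));
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    });
}
```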

Output:

```
0
pool-1-thread-1:0
1
pool-1-thread-1:1
2
pool-1-thread-1:2
3
pool-1-thread-1:3
4
pool-1-thread-1:4
5
pool-1-thread-1:5
6
pool-1-thread-1:6
7
pool-1-thread-1:7
8
pool-1-thread-1:8
9
pool-1-thread-1:9
```

```java
private ExecutorService service = Executors.newFixedThreadPool(1);

@Test
public void testTransmittableThreadLocalByTtlThreadPool() {
    // decorate the pool so every submitted task is wrapped automatically
    service = TtlExecutors.getTtlExecutorService(service);
    TransmittableThreadLocal<Integer> threadLocal = new TransmittableThreadLocal<>();
    IntStream.range(0, 10).forEach(i -> {
        System.out.println(i);
        threadLocal.set(i);
        service.execute(() ->
                System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get()));
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    });
}
```

Output:

```
0
pool-1-thread-1:0
1
pool-1-thread-1:1
2
pool-1-thread-1:2
3
pool-1-thread-1:3
4
pool-1-thread-1:4
5
pool-1-thread-1:5
6
pool-1-thread-1:6
7
pool-1-thread-1:7
8
pool-1-thread-1:8
9
pool-1-thread-1:9
```

As the output shows, when combined with TtlExecutors, TtlRunnable or TtlCallable, TransmittableThreadLocal achieves what InheritableThreadLocal cannot: passing ThreadLocal values to threads that are being reused.

Source code

holder

```java
// Note about holder:
// 1. The value of holder is type Map<TransmittableThreadLocal<?>, ?> (WeakHashMap implementation),
//    but it is used as *set*.
// 2. WeakHashMap support null value.
private static InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>> holder =
        new InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>>() {
            @Override
            protected Map<TransmittableThreadLocal<?>, ?> initialValue() {
                return new WeakHashMap<TransmittableThreadLocal<?>, Object>();
            }

            @Override
            protected Map<TransmittableThreadLocal<?>, ?> childValue(Map<TransmittableThreadLocal<?>, ?> parentValue) {
                return new WeakHashMap<TransmittableThreadLocal<?>, Object>(parentValue);
            }
        };
```

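holder is a static InheritableThreadLocal whose value is a WeakHashMap used as a set of TransmittableThreadLocal instances. It overrides both initialValue and childValue so that each one creates a new WeakHashMap.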

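  • initialValue supplies a new, empty WeakHashMap the first time holder.get() is called in a thread.
  • childValue, mentioned in the previous post on InheritableThreadLocal, is called when a child thread is created, while the parent thread's inheritable ThreadLocalMap is being copied. Here it copies parentValue into a new WeakHashMap, so the child gets its own map instead of sharing the parent's.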
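The copy made in childValue can be checked with plain JDK classes. The sketch below (hypothetical class name; Object keys stand in for TransmittableThreadLocal instances) builds a holder of the same shape. The parent registers one key, then a child thread adds a second key to its inherited map; the parent's map is unaffected because the child received a copy.

```java
import java.util.Map;
import java.util.WeakHashMap;

public class HolderCopyDemo {
    // Strongly referenced keys, so the weak entries cannot be cleared during the demo.
    static final Object PARENT_KEY = new Object();
    static final Object CHILD_KEY = new Object();

    // Same shape as TTL's holder, with Object keys standing in for TransmittableThreadLocal.
    static final InheritableThreadLocal<Map<Object, ?>> HOLDER =
            new InheritableThreadLocal<Map<Object, ?>>() {
                @Override
                protected Map<Object, ?> initialValue() {
                    return new WeakHashMap<Object, Object>();
                }

                @Override
                protected Map<Object, ?> childValue(Map<Object, ?> parentValue) {
                    // Runs while the child Thread is being constructed: copy, don't share.
                    return new WeakHashMap<Object, Object>(parentValue);
                }
            };

    /** Returns {parent's map size, child's map size} after the child adds its own key. */
    static int[] run() {
        HOLDER.get().put(PARENT_KEY, null); // the parent registers one key before forking
        int[] childSize = new int[1];
        Thread child = new Thread(() -> {
            HOLDER.get().put(CHILD_KEY, null); // only changes the child's copy
            childSize[0] = HOLDER.get().size();
        });
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return new int[] {HOLDER.get().size(), childSize[0]};
    }

    public static void main(String[] args) {
        int[] sizes = run();
        System.out.println("parent=" + sizes[0] + " child=" + sizes[1]); // parent=1 child=2
    }
}
```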

The set method and related methods

```java
@Override
public final void set(T value) {
    super.set(value);
    // may set null to remove value
    if (null == value) removeValue();
    else addValue();
}

private void removeValue() {
    holder.get().remove(this);
}

private void addValue() {
    if (!holder.get().containsKey(this)) {
        holder.get().put(this, null); // WeakHashMap supports null value.
    }
}
```

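Every set also updates holder: a non-null value registers this instance through addValue, and null unregisters it through removeValue. holder.get() returns a Map whose keys are TransmittableThreadLocal instances and whose values are Object, but only the keys matter, so addValue stores null as the value. Because the map is a WeakHashMap, the entry does not keep the TransmittableThreadLocal reachable: once nothing else references it, the JVM can garbage-collect it when needed and the entry disappears from the map.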
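The register-on-set bookkeeping can be reproduced in a few lines. The following is a hypothetical, simplified RegisteringThreadLocal, not TTL's class: it keeps the WeakHashMap-as-set idea with null values, but for brevity stores the registry in a plain ThreadLocal rather than an InheritableThreadLocal.

```java
import java.util.Map;
import java.util.WeakHashMap;

public class RegistryDemo {
    /** Hypothetical, simplified version of TTL's set/addValue/removeValue bookkeeping. */
    static class RegisteringThreadLocal<T> extends InheritableThreadLocal<T> {
        // Per-thread registry used as a set: keys are the instances, values are always null.
        // Simplification: TTL keeps this map in an InheritableThreadLocal (`holder`).
        static final ThreadLocal<Map<RegisteringThreadLocal<?>, Object>> REGISTRY =
                ThreadLocal.withInitial(WeakHashMap::new);

        @Override
        public void set(T value) {
            super.set(value);
            if (value == null) {
                REGISTRY.get().remove(this);     // setting null unregisters
            } else if (!REGISTRY.get().containsKey(this)) {
                REGISTRY.get().put(this, null);  // WeakHashMap accepts null values
            }
        }

        boolean isRegistered() {
            // values are null, so membership must be tested with containsKey, not get
            return REGISTRY.get().containsKey(this);
        }
    }

    /** Returns {registered after set("a"), registered after set(null)}. */
    static boolean[] demo() {
        RegisteringThreadLocal<String> local = new RegisteringThreadLocal<>();
        local.set("a");
        boolean afterSet = local.isRegistered();
        local.set(null);
        boolean afterNull = local.isRegistered();
        return new boolean[] {afterSet, afterNull};
    }

    public static void main(String[] args) {
        boolean[] r = demo();
        System.out.println(r[0] + " " + r[1]); // true false
    }
}
```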

The get method

```java
@Override
public final T get() {
    T value = super.get();
    if (null != value) addValue();
    return value;
}
```

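It mostly relies on the parent class's get; the only addition is that a non-null value registers this instance in holder, so holder keeps track of the TransmittableThreadLocals in use on the current thread.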

TtlRunnable

```java
private TtlRunnable(@Nonnull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
    // take a snapshot (capture) of the current thread's TTL values and keep it in an AtomicReference
    this.capturedRef = new AtomicReference<Object>(capture());
    this.runnable = runnable;
    this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}

@Override
public void run() {
    Object captured = capturedRef.get();
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }
    // back up the worker's context and replay the captured one
    Object backup = replay(captured);
    try {
        runnable.run();
    } finally {
        // restore the backup
        restore(backup);
    }
}
```
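The AtomicReference filled in the constructor and the compareAndSet in run() make a wrapper created with releaseTtlValueReferenceAfterRun = true single-use. The hypothetical SnapshotTask below copies only that check (no real context handling) to show the effect: with release enabled the second run() throws IllegalStateException, while without it both runs succeed.

```java
import java.util.concurrent.atomic.AtomicReference;

public class OneShotCaptureDemo {
    /** Hypothetical task that copies only the capturedRef check of TtlRunnable.run(). */
    static final class SnapshotTask implements Runnable {
        private final AtomicReference<Object> capturedRef;
        private final boolean releaseAfterRun;
        private final Runnable body;

        SnapshotTask(Object snapshot, boolean releaseAfterRun, Runnable body) {
            this.capturedRef = new AtomicReference<>(snapshot); // fixed at construction time
            this.releaseAfterRun = releaseAfterRun;
            this.body = body;
        }

        @Override
        public void run() {
            Object captured = capturedRef.get();
            // With releaseAfterRun, the CAS clears the reference before the body runs,
            // so any later run() finds null and fails fast.
            if (captured == null || releaseAfterRun && !capturedRef.compareAndSet(captured, null)) {
                throw new IllegalStateException("snapshot already released");
            }
            body.run();
        }
    }

    /** Runs the same task twice and counts the runs that did not throw. */
    static int successfulRuns(boolean releaseAfterRun) {
        SnapshotTask task = new SnapshotTask("captured-context", releaseAfterRun, () -> { });
        int ok = 0;
        for (int i = 0; i < 2; i++) {
            try {
                task.run();
                ok++;
            } catch (IllegalStateException e) {
                // expected for the second run when the snapshot was released
            }
        }
        return ok;
    }

    public static void main(String[] args) {
        System.out.println(successfulRuns(true) + " " + successfulRuns(false)); // 1 2
    }
}
```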

Next, let's look at the replay and restore methods:

```java
@Nonnull
public static Object replay(@Nonnull Object captured) {
    // the captured (snapshot) TransmittableThreadLocal map
    @SuppressWarnings("unchecked")
    Map<TransmittableThreadLocal<?>, Object> capturedMap = (Map<TransmittableThreadLocal<?>, Object>) captured;
    // map used to back up the current thread's TransmittableThreadLocal values
    Map<TransmittableThreadLocal<?>, Object> backup = new HashMap<TransmittableThreadLocal<?>, Object>();

    for (Iterator<? extends Map.Entry<TransmittableThreadLocal<?>, ?>> iterator = holder.get().entrySet().iterator();
         iterator.hasNext(); ) {
        Map.Entry<TransmittableThreadLocal<?>, ?> next = iterator.next();
        TransmittableThreadLocal<?> threadLocal = next.getKey();

        // backup
        backup.put(threadLocal, threadLocal.get());

        // clear the TTL values that is not in captured
        // avoid the extra TTL values after replay when run task
        if (!capturedMap.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // set values to captured TTL
    setTtlValuesTo(capturedMap);

    // call beforeExecute callback
    doExecuteCallback(true);

    return backup;
}

public static void restore(@Nonnull Object backup) {
    @SuppressWarnings("unchecked")
    Map<TransmittableThreadLocal<?>, Object> backupMap = (Map<TransmittableThreadLocal<?>, Object>) backup;
    // call afterExecute callback
    doExecuteCallback(false);

    for (Iterator<? extends Map.Entry<TransmittableThreadLocal<?>, ?>> iterator = holder.get().entrySet().iterator();
         iterator.hasNext(); ) {
        Map.Entry<TransmittableThreadLocal<?>, ?> next = iterator.next();
        TransmittableThreadLocal<?> threadLocal = next.getKey();

        // clear the TTL values that is not in backup
        // avoid the extra TTL values after restore
        if (!backupMap.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // restore TTL values
    setTtlValuesTo(backupMap);
}

private static void setTtlValuesTo(@Nonnull Map<TransmittableThreadLocal<?>, Object> ttlValues) {
    for (Map.Entry<TransmittableThreadLocal<?>, Object> entry : ttlValues.entrySet()) {
        @SuppressWarnings("unchecked")
        TransmittableThreadLocal<Object> threadLocal = (TransmittableThreadLocal<Object>) entry.getKey();
        threadLocal.set(entry.getValue());
    }
}

private static void doExecuteCallback(boolean isBefore) {
    for (Map.Entry<TransmittableThreadLocal<?>, ?> entry : holder.get().entrySet()) {
        TransmittableThreadLocal<?> threadLocal = entry.getKey();

        try {
            if (isBefore) threadLocal.beforeExecute();
            else threadLocal.afterExecute();
        } catch (Throwable t) {
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "TTL exception when " + (isBefore ? "beforeExecute" : "afterExecute") + ", cause: " + t.toString(), t);
            }
        }
    }
}
```

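Before run actually executes, replay backs up the TransmittableThreadLocal values the worker thread already holds and replaces them with the captured ones; after execution, restore puts the backup back. beforeExecute and afterExecute are callbacks invoked before and after the task runs. Together with the capture() done in the constructor (in the submitting thread), the work on the executing thread boils down to two steps:

  • Before run: back up the current thread's context, remove values that are not in the snapshot, and set the captured values on the current thread (replay).
  • After run: call TransmittableThreadLocal.Transmitter.restore with the backup, which deletes any context added in the meantime and copies the backed-up context back onto the current thread.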
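The capture, replay, run, restore cycle can be sketched without TTL. In this hypothetical MiniTransmitter the context variables are a fixed list of two ThreadLocals (TTL instead discovers them per thread through holder), and the beforeExecute/afterExecute callbacks are omitted. The worker thread starts with its own USER value; inside the wrapped task it sees the submitter's values, and afterwards its own value is back and the borrowed TRACE value is gone.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class MiniTransmitter {
    // Hypothetical fixed set of context variables; TTL tracks them dynamically via `holder`.
    static final ThreadLocal<String> USER = new ThreadLocal<>();
    static final ThreadLocal<String> TRACE = new ThreadLocal<>();
    static final List<ThreadLocal<String>> VARS = List.of(USER, TRACE);

    /** capture(): snapshot the non-null values of the current thread. */
    static Map<ThreadLocal<String>, String> capture() {
        Map<ThreadLocal<String>, String> snapshot = new HashMap<>();
        for (ThreadLocal<String> var : VARS) {
            String value = var.get();
            if (value != null) snapshot.put(var, value);
        }
        return snapshot;
    }

    /** Make the current thread hold exactly `values`, removing variables absent from it. */
    private static void apply(Map<ThreadLocal<String>, String> values) {
        for (ThreadLocal<String> var : VARS) {
            String value = values.get(var);
            if (value == null) var.remove();
            else var.set(value);
        }
    }

    /** replay(): back up the worker's own context, then install the captured one. */
    static Map<ThreadLocal<String>, String> replay(Map<ThreadLocal<String>, String> captured) {
        Map<ThreadLocal<String>, String> backup = capture();
        apply(captured);
        return backup;
    }

    /** restore(): put the worker's backed-up context back. */
    static void restore(Map<ThreadLocal<String>, String> backup) {
        apply(backup);
    }

    /** Capture now (submitting thread); replay and restore later (worker thread). */
    static Runnable wrap(Runnable task) {
        Map<ThreadLocal<String>, String> captured = capture();
        return () -> {
            Map<ThreadLocal<String>, String> backup = replay(captured);
            try {
                task.run();
            } finally {
                restore(backup);
            }
        };
    }

    /** Returns {USER inside, TRACE inside, USER after, TRACE after} as seen by the worker. */
    static String[] demo() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            pool.submit(() -> USER.set("worker")).get(); // the worker's own context
            USER.set("alice");                          // the submitter's context
            TRACE.set("t-1");
            String[] inside = new String[2];
            pool.submit(wrap(() -> {
                inside[0] = USER.get();
                inside[1] = TRACE.get();
            })).get();
            Callable<String[]> readAfter = () -> new String[] {USER.get(), TRACE.get()};
            String[] after = pool.submit(readAfter).get();
            return new String[] {inside[0], inside[1], after[0], after[1]};
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
            USER.remove();
            TRACE.remove();
        }
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(demo())); // [alice, t-1, worker, null]
    }
}
```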

TTL thread pools

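When a task is submitted to a TTL thread pool, ExecutorTtlWrapper's execute method runs. It calls TtlRunnable.get(command), and get creates and returns a TtlRunnable wrapping the command, which is then handed to the underlying pool. Interested readers can explore the source themselves.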
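The decorator idea behind ExecutorTtlWrapper can be shown with a hypothetical ContextExecutor (not TTL's code). It implements Executor, reads a single ThreadLocal in the caller's thread inside execute(), and wraps the command before passing it to the real pool, so tasks running on one reused worker each see the value that was current when they were submitted.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class ContextExecutorDemo {
    static final ThreadLocal<String> REQUEST_ID = new ThreadLocal<>();

    /** Hypothetical decorator: wraps each command at execute() time. */
    static final class ContextExecutor implements Executor {
        private final Executor delegate;

        ContextExecutor(Executor delegate) {
            this.delegate = delegate;
        }

        @Override
        public void execute(Runnable command) {
            String captured = REQUEST_ID.get(); // read in the submitting thread
            delegate.execute(() -> {
                String backup = REQUEST_ID.get(); // the worker's own value
                REQUEST_ID.set(captured);
                try {
                    command.run();
                } finally {
                    if (backup == null) REQUEST_ID.remove();
                    else REQUEST_ID.set(backup);
                }
            });
        }
    }

    /** Submits three tasks through one reused worker and returns the IDs they observed. */
    static List<String> demo() {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Executor executor = new ContextExecutor(pool);
        List<String> seen = Collections.synchronizedList(new ArrayList<>());
        for (int i = 0; i < 3; i++) {
            REQUEST_ID.set("req-" + i);
            executor.execute(() -> seen.add(REQUEST_ID.get()));
        }
        REQUEST_ID.remove();
        pool.shutdown();
        try {
            pool.awaitTermination(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return new ArrayList<>(seen);
    }

    public static void main(String[] args) {
        System.out.println(demo()); // [req-0, req-1, req-2]
    }
}
```

Callers keep using a plain Executor; the wrapping is invisible to them, which is why TtlExecutors.getTtlExecutorService can be dropped in front of an existing pool.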

Repository

  • https://github.com/alibaba/transmittable-thread-local

Applications

log4j2 MDC:

```xml
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>log4j2-ttl-thread-context-map</artifactId>
    <version>1.2.0</version>
</dependency>
```