A Look at Thread Variable Binding: TransmittableThreadLocal
- December 19, 2019
- Notes
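In the previous post we saw that InheritableThreadLocal cannot pass ThreadLocal values along when threads are reused, because it copies values only at the moment a thread is created. TransmittableThreadLocal (TTL), open-sourced by Alibaba, provides ThreadLocal value transmission for components that pool and reuse threads, such as thread pools, and so solves the problem of passing context during asynchronous execution. TransmittableThreadLocal must be used together with TtlExecutors, TtlRunnable and TtlCallable provided by TTL; alternatively, a Java Agent can decorate thread pools non-intrusively. TransmittableThreadLocal itself extends InheritableThreadLocal.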
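As a reminder of the problem TTL solves, the following JDK-only sketch (no TTL classes; `InheritableReuseDemo` and `CONTEXT` are hypothetical names) submits five tasks to a single-thread pool. The worker thread is constructed during the first `execute()` call, so it copies the value 0 from the submitting thread at that moment and never sees the later values.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical demo: plain JDK InheritableThreadLocal combined with a reused pool thread.
class InheritableReuseDemo {
    // Values are copied into a child thread only when that thread is constructed.
    static final InheritableThreadLocal<Integer> CONTEXT = new InheritableThreadLocal<>();

    static List<Integer> run() {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        List<Integer> seen = Collections.synchronizedList(new ArrayList<>());
        for (int i = 0; i < 5; i++) {
            CONTEXT.set(i);
            // The only worker is created inside the first execute() call and inherits 0;
            // later set() calls change the submitter's copy, not the worker's.
            pool.execute(() -> seen.add(CONTEXT.get()));
        }
        pool.shutdown();
        try {
            pool.awaitTermination(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return seen;
    }
}
```

Calling `InheritableReuseDemo.run()` should return `[0, 0, 0, 0, 0]`.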
Example
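```java
// service is the single-thread pool field declared with the next test:
// private ExecutorService service = Executors.newFixedThreadPool(1);
@Test
public void testTtlRunnableTransmittableThreadLocalByThreadPool() {
    TransmittableThreadLocal<Integer> threadLocal = new TransmittableThreadLocal<>();
    IntStream.range(0, 10).forEach(i -> {
        System.out.println(i);
        threadLocal.set(i);
        // Wrap the task in TtlRunnable so the current TTL values are captured at submission time
        service.execute(TtlRunnable.get(() ->
                System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get())));
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    });
}
```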
Output:
```
0
pool-1-thread-1:0
1
pool-1-thread-1:1
2
pool-1-thread-1:2
3
pool-1-thread-1:3
4
pool-1-thread-1:4
5
pool-1-thread-1:5
6
pool-1-thread-1:6
7
pool-1-thread-1:7
8
pool-1-thread-1:8
9
pool-1-thread-1:9
```
```java
private ExecutorService service = Executors.newFixedThreadPool(1);

@Test
public void testTransmittableThreadLocalByTtlThreadPool() {
    // Decorate the pool: every submitted task is wrapped automatically, no TtlRunnable.get needed
    service = TtlExecutors.getTtlExecutorService(service);
    TransmittableThreadLocal<Integer> threadLocal = new TransmittableThreadLocal<>();
    IntStream.range(0, 10).forEach(i -> {
        System.out.println(i);
        threadLocal.set(i);
        service.execute(() ->
                System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get()));
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    });
}
```
Output:
```
0
pool-1-thread-1:0
1
pool-1-thread-1:1
2
pool-1-thread-1:2
3
pool-1-thread-1:3
4
pool-1-thread-1:4
5
pool-1-thread-1:5
6
pool-1-thread-1:6
7
pool-1-thread-1:7
8
pool-1-thread-1:8
9
pool-1-thread-1:9
```
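As the output shows, when combined with TtlExecutors, TtlRunnable or TtlCallable, TransmittableThreadLocal achieves what InheritableThreadLocal cannot: passing ThreadLocal values to reused threads. Even though the single pool thread is reused for all ten tasks, each task sees the value set just before it was submitted.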
Source code
holder
```java
// Note about holder:
// 1. The value of holder is type Map<TransmittableThreadLocal<?>, ?> (WeakHashMap implementation),
//    but it is used as *set*.
// 2. WeakHashMap support null value.
private static InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>> holder =
        new InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>>() {
            @Override
            protected Map<TransmittableThreadLocal<?>, ?> initialValue() {
                return new WeakHashMap<TransmittableThreadLocal<?>, Object>();
            }

            @Override
            protected Map<TransmittableThreadLocal<?>, ?> childValue(Map<TransmittableThreadLocal<?>, ?> parentValue) {
                return new WeakHashMap<TransmittableThreadLocal<?>, Object>(parentValue);
            }
        };
```
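holder is a static InheritableThreadLocal whose value is a WeakHashMap: both initialValue and childValue return a new WeakHashMap. The map records, per thread, which TransmittableThreadLocal instances currently hold a value.
- initialValue supplies the empty map the first time holder is read in a thread.
- childValue, mentioned in the previous post on InheritableThreadLocal, is called when a child thread is created and its ThreadLocalMap is copied from the parent thread. Here parentValue is copied into a new WeakHashMap, so the child gets its own map instead of sharing the parent's.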
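The effect of overriding `childValue` with a copy can be seen in a small JDK-only sketch. `REGISTRY` and `ChildValueCopyDemo` are hypothetical stand-ins for holder, and a `HashMap` replaces the `WeakHashMap` so the example does not depend on garbage collection. The child thread sees the parent's entry, but its own additions do not leak back into the parent's map.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical demo mimicking holder's initialValue/childValue pair.
class ChildValueCopyDemo {
    static final InheritableThreadLocal<Map<String, Object>> REGISTRY =
            new InheritableThreadLocal<Map<String, Object>>() {
                @Override
                protected Map<String, Object> initialValue() {
                    return new HashMap<>();
                }

                @Override
                protected Map<String, Object> childValue(Map<String, Object> parentValue) {
                    // Shallow copy: the child starts with the parent's entries in its own map.
                    return new HashMap<>(parentValue);
                }
            };

    /** Returns {childSawParentKey, parentSawChildKey}. */
    static boolean[] run() {
        REGISTRY.get().put("parentKey", "p");
        boolean[] childSawParentKey = new boolean[1];
        Thread child = new Thread(() -> {
            Map<String, Object> mine = REGISTRY.get();
            childSawParentKey[0] = mine.containsKey("parentKey");
            mine.put("childKey", "c"); // only touches the child's copy
        });
        child.start(); // childValue ran when the Thread was constructed above
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        boolean parentSawChildKey = REGISTRY.get().containsKey("childKey");
        return new boolean[] {childSawParentKey[0], parentSawChildKey};
    }
}
```

`run()` should return `{true, false}`. Had `childValue` returned `parentValue` itself, both threads would share one map and the second flag would become true.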
The set method and related methods
```java
@Override
public final void set(T value) {
    super.set(value); // may set null to remove value
    if (null == value) removeValue();
    else addValue();
}

private void removeValue() {
    holder.get().remove(this);
}

private void addValue() {
    if (!holder.get().containsKey(this)) {
        holder.get().put(this, null); // WeakHashMap supports null value.
    }
}
```
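Every call to set also updates holder: a non-null value registers this TransmittableThreadLocal via addValue, and null unregisters it via removeValue. holder.get() returns a Map whose keys are TransmittableThreadLocal and whose values are Object, but only the keys matter: addValue stores this with a null value, so the map is effectively used as a set. Because it is a WeakHashMap, the keys are held weakly; once no strong reference to a TransmittableThreadLocal remains, the JVM can garbage-collect it and its entry disappears from holder, so holder never keeps a TransmittableThreadLocal alive by itself.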
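A minimal sketch, assuming only `java.util.WeakHashMap`, of why `addValue` checks `containsKey` rather than `get`: with null values, `get` cannot distinguish "present" from "absent". `WeakSetDemo` is a hypothetical name. The GC-driven removal is deliberately not asserted, since when it happens is up to the JVM.

```java
import java.util.Map;
import java.util.WeakHashMap;

// Hypothetical demo: a WeakHashMap with null values used as a set.
class WeakSetDemo {
    /** Returns {presentViaContainsKey, presentViaGet, emptyAfterRemove}. */
    static boolean[] run() {
        Map<Object, Object> set = new WeakHashMap<>();
        Object key = new Object(); // strongly reachable through this local variable
        set.put(key, null);        // null value: only membership of the key matters
        boolean viaContainsKey = set.containsKey(key); // true
        boolean viaGet = set.get(key) != null;         // false: null value looks like "absent"
        set.remove(key);
        // Without the remove, the entry would vanish on its own at some point after
        // `key` became unreachable and a GC cleared the weak reference.
        return new boolean[] {viaContainsKey, viaGet, set.isEmpty()};
    }
}
```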
The get method
```java
@Override
public final T get() {
    T value = super.get();
    if (null != value) addValue();
    return value;
}
```
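get mostly delegates to the parent class. The only addition is holder bookkeeping: a non-null value, including one supplied by initialValue(), registers this TransmittableThreadLocal in holder so that it will be included when the context is captured later.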
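The bookkeeping in `set` and `get` can be reproduced in a simplified model. `MiniTtl` and `MiniTtlDemo` are hypothetical classes written for illustration, not part of the TTL library; `MiniTtl` keeps only the holder registration logic from the source above and leaves out capture, replay and callbacks. The last check shows that `get()` also registers a value that came from `initialValue()`.

```java
import java.util.Map;
import java.util.WeakHashMap;

// Hypothetical, simplified model of TransmittableThreadLocal's holder bookkeeping.
class MiniTtl<T> extends InheritableThreadLocal<T> {
    // Per-thread "set" of MiniTtl instances that currently hold a value.
    static final InheritableThreadLocal<Map<MiniTtl<?>, ?>> HOLDER =
            new InheritableThreadLocal<Map<MiniTtl<?>, ?>>() {
                @Override
                protected Map<MiniTtl<?>, ?> initialValue() {
                    return new WeakHashMap<MiniTtl<?>, Object>();
                }

                @Override
                protected Map<MiniTtl<?>, ?> childValue(Map<MiniTtl<?>, ?> parentValue) {
                    return new WeakHashMap<MiniTtl<?>, Object>(parentValue);
                }
            };

    @Override
    public final void set(T value) {
        super.set(value);
        if (value == null) HOLDER.get().remove(this); // null unregisters
        else addValue();
    }

    @Override
    public final T get() {
        T value = super.get();
        if (value != null) addValue(); // also covers values produced by initialValue()
        return value;
    }

    private void addValue() {
        if (!HOLDER.get().containsKey(this)) HOLDER.get().put(this, null);
    }
}

class MiniTtlDemo {
    /** Returns {registeredAfterSet, registeredAfterNullSet, registeredAfterDefaultGet}. */
    static boolean[] run() {
        MiniTtl<String> plain = new MiniTtl<>();
        MiniTtl<String> withDefault = new MiniTtl<String>() {
            @Override
            protected String initialValue() {
                return "default";
            }
        };
        plain.set("v");
        boolean afterSet = MiniTtl.HOLDER.get().containsKey(plain);
        plain.set(null);
        boolean afterNullSet = MiniTtl.HOLDER.get().containsKey(plain);
        withDefault.get(); // never set(), yet registered because the value is non-null
        boolean afterDefaultGet = MiniTtl.HOLDER.get().containsKey(withDefault);
        return new boolean[] {afterSet, afterNullSet, afterDefaultGet};
    }
}
```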
TtlRunnable
```java
private TtlRunnable(@Nonnull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
    // capture() takes a snapshot of the current TTL values; it is kept in an AtomicReference
    this.capturedRef = new AtomicReference<Object>(capture());
    this.runnable = runnable;
    this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}

@Override
public void run() {
    Object captured = capturedRef.get();
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }

    // replay the captured context; returns a backup of the current thread's context
    Object backup = replay(captured);
    try {
        runnable.run();
    } finally {
        // restore the backed-up context
        restore(backup);
    }
}
```
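The `capturedRef` check in `run()` ensures that, when `releaseTtlValueReferenceAfterRun` is true, a TtlRunnable runs at most once: the first caller wins the `compareAndSet` and clears the reference, and any later call finds null and throws. The hypothetical `OneShotRunnable` below isolates just that check; the captured object is an arbitrary placeholder, not a real TTL snapshot, and replay/restore are omitted.

```java
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical wrapper reproducing only the capturedRef / release-after-run check.
class OneShotRunnable implements Runnable {
    private final AtomicReference<Object> capturedRef;
    private final Runnable runnable;
    private final boolean releaseAfterRun;

    OneShotRunnable(Object snapshot, Runnable runnable, boolean releaseAfterRun) {
        this.capturedRef = new AtomicReference<>(snapshot);
        this.runnable = runnable;
        this.releaseAfterRun = releaseAfterRun;
    }

    @Override
    public void run() {
        Object captured = capturedRef.get();
        // When releasing, only the caller that swaps the snapshot for null may proceed.
        if (captured == null || releaseAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("captured value reference is released after run!");
        }
        runnable.run();
    }
}

class OneShotDemo {
    /** Returns {totalRuns, rejectedCalls}. */
    static int[] run() {
        int[] counter = new int[1];
        OneShotRunnable once = new OneShotRunnable("snapshot", () -> counter[0]++, true);
        once.run();
        int rejected = 0;
        try {
            once.run(); // snapshot already released
        } catch (IllegalStateException e) {
            rejected++;
        }
        OneShotRunnable reusable = new OneShotRunnable("snapshot", () -> counter[0]++, false);
        reusable.run();
        reusable.run();
        return new int[] {counter[0], rejected};
    }
}
```

`OneShotDemo.run()` returns `{3, 1}`: the releasing wrapper ran once and rejected its second call, while the non-releasing one ran twice.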
Next, let's look at the replay and restore methods:
```java
@Nonnull
public static Object replay(@Nonnull Object captured) {
    // snapshot map of the TransmittableThreadLocal values captured in the submitting thread
    @SuppressWarnings("unchecked")
    Map<TransmittableThreadLocal<?>, Object> capturedMap = (Map<TransmittableThreadLocal<?>, Object>) captured;
    // map used to back up the current thread's TransmittableThreadLocal values
    Map<TransmittableThreadLocal<?>, Object> backup = new HashMap<TransmittableThreadLocal<?>, Object>();

    for (Iterator<? extends Map.Entry<TransmittableThreadLocal<?>, ?>> iterator = holder.get().entrySet().iterator();
         iterator.hasNext(); ) {
        Map.Entry<TransmittableThreadLocal<?>, ?> next = iterator.next();
        TransmittableThreadLocal<?> threadLocal = next.getKey();

        // backup
        backup.put(threadLocal, threadLocal.get());

        // clear the TTL values that is not in captured
        // avoid the extra TTL values after replay when run task
        if (!capturedMap.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // set values to captured TTL
    setTtlValuesTo(capturedMap);

    // call beforeExecute callback
    doExecuteCallback(true);

    return backup;
}

public static void restore(@Nonnull Object backup) {
    @SuppressWarnings("unchecked")
    Map<TransmittableThreadLocal<?>, Object> backupMap = (Map<TransmittableThreadLocal<?>, Object>) backup;
    // call afterExecute callback
    doExecuteCallback(false);

    for (Iterator<? extends Map.Entry<TransmittableThreadLocal<?>, ?>> iterator = holder.get().entrySet().iterator();
         iterator.hasNext(); ) {
        Map.Entry<TransmittableThreadLocal<?>, ?> next = iterator.next();
        TransmittableThreadLocal<?> threadLocal = next.getKey();

        // clear the TTL values that is not in backup
        // avoid the extra TTL values after restore
        if (!backupMap.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // restore TTL values
    setTtlValuesTo(backupMap);
}

private static void setTtlValuesTo(@Nonnull Map<TransmittableThreadLocal<?>, Object> ttlValues) {
    for (Map.Entry<TransmittableThreadLocal<?>, Object> entry : ttlValues.entrySet()) {
        @SuppressWarnings("unchecked")
        TransmittableThreadLocal<Object> threadLocal = (TransmittableThreadLocal<Object>) entry.getKey();
        threadLocal.set(entry.getValue());
    }
}

private static void doExecuteCallback(boolean isBefore) {
    for (Map.Entry<TransmittableThreadLocal<?>, ?> entry : holder.get().entrySet()) {
        TransmittableThreadLocal<?> threadLocal = entry.getKey();

        try {
            if (isBefore) threadLocal.beforeExecute();
            else threadLocal.afterExecute();
        } catch (Throwable t) {
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "TTL exception when " + (isBefore ? "beforeExecute" : "afterExecute")
                        + ", cause: " + t.toString(), t);
            }
        }
    }
}
```
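Before actually executing run, TtlRunnable first backs up the TransmittableThreadLocal values already present in the current (worker) thread and replays the captured ones; after run completes, it restores the backup. beforeExecute and afterExecute are callbacks invoked before and after the task. In short, there are two steps:
- Before run, copy the current thread's context as a backup, clear TTL values that were not captured, and set the values captured from the submitting thread.
- After run, pass this backup to TransmittableThreadLocal.Transmitter.restore, which removes any TTL values added in the meantime that are not in the backup and sets the backed-up values back into the current thread.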
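The two steps can be sketched with plain `ThreadLocal`s. This is a simplified model under stated assumptions, not TTL's implementation: `MiniTransmitter`, `newVar`, `wrap` and `ReplayDemo` are hypothetical names, every variable is registered in one global set instead of a per-thread holder, and the beforeExecute/afterExecute callbacks are omitted.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;

// Hypothetical simplified model of capture/replay/restore.
class MiniTransmitter {
    // All context variables; a null value means "no value in this thread".
    static final Set<ThreadLocal<Object>> REGISTERED = new CopyOnWriteArraySet<>();

    static ThreadLocal<Object> newVar() {
        ThreadLocal<Object> tl = new ThreadLocal<>();
        REGISTERED.add(tl);
        return tl;
    }

    // Snapshot the calling thread's non-null values.
    static Map<ThreadLocal<Object>, Object> capture() {
        Map<ThreadLocal<Object>, Object> snapshot = new HashMap<>();
        for (ThreadLocal<Object> tl : REGISTERED) {
            Object v = tl.get();
            if (v != null) snapshot.put(tl, v);
        }
        return snapshot;
    }

    // Back up the current values, then install `values` and clear everything else.
    static Map<ThreadLocal<Object>, Object> replay(Map<ThreadLocal<Object>, Object> values) {
        Map<ThreadLocal<Object>, Object> backup = capture();
        for (ThreadLocal<Object> tl : REGISTERED) {
            if (values.containsKey(tl)) tl.set(values.get(tl));
            else tl.remove();
        }
        return backup;
    }

    // Undo replay: values not in the backup (including ones the task added) are removed.
    static void restore(Map<ThreadLocal<Object>, Object> backup) {
        replay(backup);
    }

    static Runnable wrap(Runnable task) {
        Map<ThreadLocal<Object>, Object> captured = capture(); // at submission time
        return () -> {
            Map<ThreadLocal<Object>, Object> backup = replay(captured);
            try {
                task.run();
            } finally {
                restore(backup);
            }
        };
    }
}

class ReplayDemo {
    static String[] run() {
        ThreadLocal<Object> user = MiniTransmitter.newVar();
        ThreadLocal<Object> trace = MiniTransmitter.newVar();
        String[] out = new String[3];
        user.set("alice"); // submitting thread's context
        Runnable wrapped = MiniTransmitter.wrap(() -> {
            out[0] = String.valueOf(user.get()); // replayed from the submitter
            trace.set("added-by-task");          // cleared again by restore()
        });
        Thread worker = new Thread(() -> {
            user.set("worker-own"); // value the worker already had
            wrapped.run();
            out[1] = String.valueOf(user.get());
            out[2] = String.valueOf(trace.get());
        });
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return out;
    }
}
```

`ReplayDemo.run()` returns `["alice", "worker-own", "null"]`: the task sees the submitter's value, and afterwards the worker gets its own value back while the variable set by the task is cleared.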
TTL thread pool
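When a task is submitted to a pool wrapped by TtlExecutors, the execute method of ExecutorTtlWrapper is called. It calls TtlRunnable.get(command), which creates and returns a TtlRunnable wrapping the command, and hands that wrapper to the underlying executor. Interested readers can look through the source for the details.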
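The wrapper idea can be sketched with a hypothetical `ContextPropagatingExecutor` that decorates any `Executor` and wraps each task at `execute()` time, in the spirit of ExecutorTtlWrapper. It propagates a single plain `ThreadLocal<Integer>` rather than all TTL values, which is a simplifying assumption.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical decorator: captures one context value per task when it is submitted.
class ContextPropagatingExecutor implements Executor {
    static final ThreadLocal<Integer> CONTEXT = new ThreadLocal<>();
    private final Executor delegate;

    ContextPropagatingExecutor(Executor delegate) {
        this.delegate = delegate;
    }

    @Override
    public void execute(Runnable command) {
        Integer captured = CONTEXT.get(); // read in the submitting thread
        delegate.execute(() -> {
            Integer backup = CONTEXT.get(); // worker's own value, if any
            CONTEXT.set(captured);
            try {
                command.run();
            } finally {
                if (backup == null) CONTEXT.remove();
                else CONTEXT.set(backup);
            }
        });
    }

    static List<Integer> demo() {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Executor executor = new ContextPropagatingExecutor(pool);
        List<Integer> seen = Collections.synchronizedList(new ArrayList<>());
        for (int i = 0; i < 5; i++) {
            CONTEXT.set(i);
            executor.execute(() -> seen.add(CONTEXT.get()));
        }
        pool.shutdown();
        try {
            pool.awaitTermination(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return seen;
    }
}
```

With a single-thread pool, `demo()` returns `[0, 1, 2, 3, 4]`, in contrast to plain InheritableThreadLocal, where the reused worker only ever sees the value present when it was created.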
Repository
- https://github.com/alibaba/transmittable-thread-local
Applications
log4j2 MDC:
```xml
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>log4j2-ttl-thread-context-map</artifactId>
    <version>1.2.0</version>
</dependency>
```