/**
 * Demonstrates {@link CountDownLatch}: one worker thread blocks on the latch
 * until two sheet threads have each counted it down once.
 */
public class CountDownLatchTest {

    /** Waits for the latch to reach zero, then simulates one second of work. */
    private static class WorkThread extends Thread {
        private final CountDownLatch cdl;

        public WorkThread(String name, CountDownLatch cdl) {
            super(name);
            this.cdl = cdl;
        }

        @Override
        public void run() {
            System.out.println(this.getName() + "启动了,时间为" + System.currentTimeMillis());
            System.out.println(this.getName() + "我要统计每个sheet的行数");
            try {
                cdl.await();
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                // Restore the interrupt status so code further up can still observe it.
                Thread.currentThread().interrupt();
                e.printStackTrace();
            }
            System.out.println(this.getName() + "执行完了,时间为" + System.currentTimeMillis());
        }
    }

    /** Simulates one second of sheet processing, then counts the latch down. */
    private static class SheetThread extends Thread {
        private final CountDownLatch cdl;

        public SheetThread(String name, CountDownLatch cdl) {
            super(name);
            this.cdl = cdl;
        }

        @Override
        public void run() {
            System.out.println(this.getName() + "启动了,时间为" + System.currentTimeMillis());
            try {
                Thread.sleep(1000); // simulate the time the task takes
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                e.printStackTrace();
                return;
            } finally {
                // Count down even when interrupted; otherwise the thread blocked in
                // await() would never be released.
                cdl.countDown();
            }
            System.out.println(this.getName() + "执行完了,时间为" + System.currentTimeMillis()
                    + " sheet的行数为:" + (int) (Math.random() * 100));
        }
    }

    /**
     * Starts one waiting worker and two sheet threads sharing a latch of count 2.
     *
     * @param args unused
     * @throws Exception declared for compatibility with the original signature
     */
    public static void main(String[] args) throws Exception {
        CountDownLatch cdl = new CountDownLatch(2);

        WorkThread wt0 = new WorkThread("WorkThread", cdl);
        wt0.start();

        SheetThread dt0 = new SheetThread("sheetThread1", cdl);
        SheetThread dt1 = new SheetThread("sheetThread2", cdl);
        dt0.start();
        dt1.start();
    }
}
WorkThread启动了,时间为1640054503027 WorkThread我要统计每个sheet的行数 sheetThread1启动了,时间为1640054503028 sheetThread2启动了,时间为1640054503029 sheetThread2执行完了,时间为1640054504031 sheet的行数为:6 sheetThread1执行完了,时间为1640054504031 sheet的行数为:44 WorkThread执行完了,时间为1640054505036
我们继续根据上面的测试案例流程,一步一步的分析CountDownLatch 源码。
/**
 * Creates a latch whose synchronizer is initialised with the given count.
 *
 * @param count the initial count; must not be negative
 * @throws IllegalArgumentException if {@code count} is negative
 */
public CountDownLatch(int count) {
    if (count < 0) {
        throw new IllegalArgumentException("count < 0");
    }
    this.sync = new Sync(count);
}
/** * Synchronization control For CountDownLatch. * Uses AQS state to represent count. */ private static final class Sync extends AbstractQueuedSynchronizer { private static final long serialVersionUID = 4982264981922014374L; Sync(int count) { setState(count); } int getCount() { return getState(); } protected int tryAcquireShared(int acquires) { return (getState() == 0) ? 1 : -1; } protected boolean tryReleaseShared(int releases) { // Decrement count; signal when transition to zero for (;;) { int c = getState(); if (c == 0) return false; int nextc = c-1; if (compareAndSetState(c, nextc)) return nextc == 0; } } }
/**
 * Waits on the latch by delegating to the synchronizer's
 * {@code acquireSharedInterruptibly} with an argument of 1.
 *
 * @throws InterruptedException propagated from the delegated call
 */
public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}
/**
 * Acquires in shared mode, aborting if the calling thread is interrupted.
 * Checks (and clears) the interrupt status first, then makes one attempt via
 * {@code tryAcquireShared}; only when that attempt returns a negative value
 * does it fall through to {@code doAcquireSharedInterruptibly}.
 *
 * @param arg the acquire argument, passed through unchanged
 * @throws InterruptedException if the current thread was interrupted on entry,
 *         or as thrown by {@code doAcquireSharedInterruptibly}
 */
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted()) {
        throw new InterruptedException();
    }
    if (tryAcquireShared(arg) >= 0) {
        return; // fast path: acquired without queuing
    }
    doAcquireSharedInterruptibly(arg);
}
如果线程中断,抛出异常,否则开始调用 tryAcquireShared(1),其内部类Sync的实现也非常简单,就是判断state也就是CountDownLatch的计数是否等于0:等于0返回1,表示获取成功,await()直接返回;否则返回-1,继续调用 doAcquireSharedInterruptibly(1) 进入等待。
/** * Acquires in shared interruptible mode. * @param arg the acquire argument */ private void doAcquireSharedInterruptibly(int arg) throws InterruptedException { final Node node = addWaiter(Node.SHARED); boolean failed = true; try { for (;;) { final Node p = node.predecessor(); if (p == head) { int r = tryAcquireShared(arg); if (r >= 0) { setHeadAndPropagate(node, r); = null; // help GC failed = false; return; } } if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) throw new InterruptedException(); } } finally { if (failed) cancelAcquire(node); } }
下面就到 countDown()方法了
/**
 * Counts the latch down by delegating to the synchronizer's
 * {@code releaseShared} with an argument of 1.
 */
public void countDown() {
    sync.releaseShared(1);
}
/**
 * Releases in shared mode. Runs {@code doReleaseShared} only when
 * {@code tryReleaseShared} reports success.
 *
 * @param arg the release argument, passed to {@code tryReleaseShared}
 * @return the value returned by {@code tryReleaseShared}
 */
public final boolean releaseShared(int arg) {
    if (!tryReleaseShared(arg)) {
        return false;
    }
    doReleaseShared();
    return true;
}
protected boolean tryReleaseShared(int releases) { // Decrement count; signal when transition to zero for (;;) { int c = getState(); if (c == 0) return false; int nextc = c-1; if (compareAndSetState(c, nextc)) return nextc == 0; } }
因此在我们的测试案例中,我们需要调用两次 countDown方法,才会将全局的state更新为0,然后继续执行doReleaseShared()方法。
/**
 * Release action for shared mode -- signals successor and ensures
 * propagation. (Note: For exclusive mode, release just amounts
 * to calling unparkSuccessor of head if it needs signal.)
 */
private void doReleaseShared() {
    /*
     * Ensure that a release propagates, even if there are other
     * in-progress acquires/releases.  This proceeds in the usual
     * way of trying to unparkSuccessor of head if it needs
     * signal. But if it does not, status is set to PROPAGATE to
     * ensure that upon release, propagation continues.
     * Additionally, we must loop in case a new node is added
     * while we are doing this. Also, unlike other uses of
     * unparkSuccessor, we need to know if CAS to reset status
     * fails, if so rechecking.
     */
    for (;;) {
        // Snapshot head; the check at the bottom detects a concurrent change.
        Node h = head;
        // Skip when the queue is empty (null head) or head is also the tail.
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                // Reset SIGNAL -> 0 before unparking; retry if the CAS loses.
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}
/** * Wakes up node's successor, if one exists. * * @param node the node */ private void unparkSuccessor(Node node) { /* * If status is negative (i.e., possibly needing signal) try * to clear in anticipation of signalling. It is OK if this * fails or if status is changed by waiting thread. */ int ws = node.waitStatus; if (ws < 0) compareAndSetWaitStatus(node, ws, 0); /* * Thread to unpark is held in successor, which is normally * just the next node. But if cancelled or apparently null, * traverse backwards from tail to find the actual * non-cancelled successor. */ Node s =; if (s == null || s.waitStatus > 0) { s = null; for (Node t = tail; t != null && t != node; t = t.prev) if (t.waitStatus <= 0) s = t; } if (s != null) LockSupport.unpark(s.thread); }
每个线程执行前先通过acquire方法获取信号,执行后通过release归还信号。每次acquire返回成功后,Semaphore可用的信号量就会减少一个,如果没有可用的信号,acquire调用就会阻塞,直到有其他线程通过release归还信号。
/**
 * Demonstrates {@link Semaphore}: ten threads compete for five permits, each
 * holding its permit for one second.
 */
public class SemaphoreTest {

    /**
     * Starts ten threads that share one task guarded by a five-permit semaphore.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        final Semaphore semaphore = new Semaphore(5);

        Runnable runnable = () -> {
            // Acquire outside the try/finally below: if acquire() is interrupted,
            // no permit was obtained, so release() must not run.
            try {
                semaphore.acquire();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                e.printStackTrace();
                return;
            }
            try {
                System.out.println(Thread.currentThread().getName() + "获得了信号量>>>>>,时间为" + System.currentTimeMillis());
                Thread.sleep(1000);
                System.out.println(Thread.currentThread().getName() + "释放了信号量<<<<<,时间为" + System.currentTimeMillis());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                e.printStackTrace();
            } finally {
                semaphore.release();
            }
        };

        Thread[] threads = new Thread[10];
        for (int i = 0; i < threads.length; i++) {
            threads[i] = new Thread(runnable);
        }
        for (int i = 0; i < threads.length; i++) {
            threads[i].start();
        }
    }
}
Thread-0获得了信号量>>>>>,时间为1640058647604 Thread-1获得了信号量>>>>>,时间为1640058647604 Thread-2获得了信号量>>>>>,时间为1640058647604 Thread-3获得了信号量>>>>>,时间为1640058647605 Thread-4获得了信号量>>>>>,时间为1640058647605 Thread-0释放了信号量<<<<<,时间为1640058648606 Thread-1释放了信号量<<<<<,时间为1640058648606 Thread-5获得了信号量>>>>>,时间为1640058648607 Thread-4释放了信号量<<<<<,时间为1640058648607 Thread-3释放了信号量<<<<<,时间为1640058648607 Thread-7获得了信号量>>>>>,时间为1640058648607 Thread-8获得了信号量>>>>>,时间为1640058648607 Thread-2释放了信号量<<<<<,时间为1640058648606 Thread-6获得了信号量>>>>>,时间为1640058648607 Thread-9获得了信号量>>>>>,时间为1640058648607 Thread-7释放了信号量<<<<<,时间为1640058649607 Thread-6释放了信号量<<<<<,时间为1640058649607 Thread-8释放了信号量<<<<<,时间为1640058649607 Thread-9释放了信号量<<<<<,时间为1640058649608 Thread-5释放了信号量<<<<<,时间为1640058649607
我们使用for循环同时创建10个线程,首先是线程 0 1 2 3 4获得了信号量,在后面的10行打印结果中,线程0到4分别释放信号量,同一线程从获得到释放的间隔也是1000毫秒,
然后线程5 6 7 8 9才能继续获得信号量,而且保持最大获取信号量的线程数小于等于5。
/**
 * Creates a semaphore backed by a {@code NonfairSync} built from the given
 * permit count.
 *
 * @param permits the value passed to the {@code NonfairSync} constructor
 */
public Semaphore(int permits) {
    sync = new NonfairSync(permits);
}
/**
 * Creates a semaphore backed by either the fair or the nonfair synchronizer.
 *
 * @param permits the value passed to the chosen synchronizer's constructor
 * @param fair {@code true} selects {@code FairSync}, {@code false} selects
 *        {@code NonfairSync}
 */
public Semaphore(int permits, boolean fair) {
    if (fair) {
        sync = new FairSync(permits);
    } else {
        sync = new NonfairSync(permits);
    }
}
/** * Synchronization implementation for semaphore. Uses AQS state * to represent permits. Subclassed into fair and nonfair * versions. */ abstract static class Sync extends AbstractQueuedSynchronizer { private static final long serialVersionUID = 1192457210091910933L; Sync(int permits) { setState(permits); } final int getPermits() { return getState(); } final int nonfairTryAcquireShared(int acquires) { for (;;) { int available = getState(); int remaining = available - acquires; if (remaining < 0 || compareAndSetState(available, remaining)) return remaining; } } protected final boolean tryReleaseShared(int releases) { for (;;) { int current = getState(); int next = current + releases; if (next < current) // overflow throw new Error("Maximum permit count exceeded"); if (compareAndSetState(current, next)) return true; } } final void reducePermits(int reductions) { for (;;) { int current = getState(); int next = current - reductions; if (next > current) // underflow throw new Error("Permit count underflow"); if (compareAndSetState(current, next)) return; } } final int drainPermits() { for (;;) { int current = getState(); if (current == 0 || compareAndSetState(current, 0)) return current; } } }
第12行 getPermits() 方法获取当前的可用的信号量,即还有多少线程可以同时获得信号量
第15行 nonfairTryAcquireShared方法尝试获取共享锁,逻辑就是直接将可用信号量减去该方法请求获取的数量,更新state并返回该值。
第24行 tryReleaseShared 方法尝试释放共享锁,逻辑就是直接将可用信号量加上该方法请求释放的数量,更新state并返回。
/** * Fair version */ static final class FairSync extends Sync { private static final long serialVersionUID = 2014338818796000944L; FairSync(int permits) { super(permits); } protected int tryAcquireShared(int acquires) { for (;;) { if (hasQueuedPredecessors()) return -1; int available = getState(); int remaining = available - acquires; if (remaining < 0 || compareAndSetState(available, remaining)) return remaining; } } }
看尝试获取共享锁的方法中,多了个 if (hasQueuedPredecessors()) 的判断,在《java多线程6:ReentrantLock》介绍公平锁时已经见过这个方法:它判断同步队列中是否有排在当前线程之前的等待线程,如果有则直接返回-1,从而保证先到先得。
/**
 * Demonstrates {@link CyclicBarrier}: each thread sleeps for its own delay,
 * then waits at a shared barrier until all parties have arrived.
 */
public class CyclicBarrierTest extends Thread {
    private final CyclicBarrier cb;
    private final int sleepSecond;

    /**
     * @param cb the barrier this thread waits on
     * @param sleepSecond seconds to sleep before calling {@code cb.await()}
     */
    public CyclicBarrierTest(CyclicBarrier cb, int sleepSecond) {
        this.cb = cb;
        this.sleepSecond = sleepSecond;
    }

    @Override
    public void run() {
        try {
            System.out.println(this.getName() + "开始, 时间为" + System.currentTimeMillis());
            // 1000L forces long arithmetic so a large sleepSecond cannot overflow int.
            Thread.sleep(sleepSecond * 1000L);
            cb.await();
            System.out.println(this.getName() + "结束, 时间为" + System.currentTimeMillis());
        } catch (InterruptedException e) {
            // Restore the interrupt status instead of swallowing it.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (java.util.concurrent.BrokenBarrierException e) {
            e.printStackTrace();
        }
    }

    /**
     * Runs two threads (3 s and 6 s delays) against a two-party barrier that
     * has a barrier action.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Runnable runnable = () ->
                System.out.println("CyclicBarrier的barrierAction开始运行, 时间为" + System.currentTimeMillis());

        CyclicBarrier cb = new CyclicBarrier(2, runnable);
        CyclicBarrierTest cbt0 = new CyclicBarrierTest(cb, 3);
        CyclicBarrierTest cbt1 = new CyclicBarrierTest(cb, 6);
        cbt0.start();
        cbt1.start();
    }
}
Thread-1开始, 时间为1640069673534 Thread-0开始, 时间为1640069673534 CyclicBarrier的barrierAction开始运行, 时间为1640069679536 Thread-1结束, 时间为1640069679536 Thread-0结束, 时间为1640069679536
看下 CyclicBarrier 的一个更高级的构造函数
/**
 * Creates a barrier for the given number of parties with an optional action.
 * Both {@code parties} and {@code count} start at the same value.
 *
 * @param parties the number of parties; must be positive
 * @param barrierAction stored as the barrier command
 * @throws IllegalArgumentException if {@code parties} is zero or negative
 */
public CyclicBarrier(int parties, Runnable barrierAction) {
    if (parties <= 0) {
        throw new IllegalArgumentException();
    }
    this.barrierCommand = barrierAction;
    this.count = parties;
    this.parties = parties;
}
Runnable barrierAction用于在线程到达屏障时,优先执行barrierAction,方便处理更复杂的业务场景。
/** * Main barrier code, covering the various policies. */ private int dowait(boolean timed, long nanos) throws InterruptedException, BrokenBarrierException, TimeoutException { final ReentrantLock lock = this.lock; lock.lock(); try { final Generation g = generation; if (g.broken) throw new BrokenBarrierException(); if (Thread.interrupted()) { breakBarrier(); throw new InterruptedException(); } int index = --count; if (index == 0) { // tripped boolean ranAction = false; try { final Runnable command = barrierCommand; if (command != null); ranAction = true; nextGeneration(); return 0; } finally { if (!ranAction) breakBarrier(); } } // loop until tripped, broken, interrupted, or timed out for (;;) { try { if (!timed) trip.await(); else if (nanos > 0L) nanos = trip.awaitNanos(nanos); } catch (InterruptedException ie) { if (g == generation && ! g.broken) { breakBarrier(); throw ie; } else { // We're about to finish waiting even if we had not // been interrupted, so this interrupt is deemed to // "belong" to subsequent execution. Thread.currentThread().interrupt(); } } if (g.broken) throw new BrokenBarrierException(); if (g != generation) return index; if (timed && nanos <= 0L) { breakBarrier(); throw new TimeoutException(); } } } finally { lock.unlock(); } }
首先是 ReentrantLock加锁,全局的count值-1,然后判断count是否等于0,如果不等于0,则循环,condition执行await等待,直到触发、屏障被破坏、中断或超时,
如果count值等于0,先执行 barrierAction(由最后一个到达屏障的线程直接调用其run方法,并不会另起线程),然后condition开始唤醒所有等待的线程。