AQS解读及其实践
- 2019 年 10 月 7 日
- 笔记
AQS概述
AQS全称AbstractQueuedSynchronizer,即抽象队列同步器。AQS是用来构建锁或者其他同步组件的基础框架,它使用一个整型的volatile变量state来维护同步状态,通过内置的FIFO队列来完成资源获取线程的排队工作。AQS为一系列同步器依赖于一个单独的原子变量state的同步器提供了一个非常有用的基础。
AQS的设计是基于模板方法模式设计的,子类通过继承AQS并实现它的抽象模板方法来管理同步状态,而这些模板方法内部就是真正管理同步状态的地方(主要有tryAcquire、tryRelease、tryAcquireShared、tryReleaseShared等)。
AQS既可以支持独占锁地,也支持共享锁,这样就可以方便实现不同类型的同步组件如ReentrantLock、ReentrantReadWriteLock和CountDownLatch等。
AQS类使用单个int(32位)来保存同步状态,并暴露出getState、setState以及compareAndSet操作来读取和更新这个同步状态。其中属性state被声明为volatile,并且通过使用CAS指令来实现compareAndSetState,使得当且仅当同步状态拥有一个一致的期望值的时候,才会被原子地设置成新值,这样就达到了同步状态的原子性管理,确保了同步状态的原子性、可见性和有序性。
补充:ReentrantReadWriteLock利用一个32位的int值保存了两个count,前16位存readCount,后16位存writeCount。

AQS核心源码解读
AQS源码中的主要字段
// 同步队列的head节点, 延迟初始化,除了初始化,只能通过setHead方法修改 // 如果head存在,waitStatus一定是CANCELLED private transient volatile Node head; // 同步队列的tail节点,延迟初始化,只能通过enq方法修改 private transient volatile Node tail; // 同步状态 private volatile int state; // 支持CAS private static final Unsafe unsafe = Unsafe.getUnsafe(); private static final long stateOffset; private static final long headOffset; private static final long tailOffset; private static final long waitStatusOffset; private static final long nextOffset;
AQS源码中的主要方法
protected final int getState() { return state; }protected final void setState(int newState) { state = newState; }protected final boolean compareAndSetState(int expect, int update) { return unsafe.compareAndSwapInt(this, stateOffset, expect, update); }// 钩子方法,独占式获取同步状态, 需要子类实现,实现此方法需要查询当前同步状态并 // 判断同步状态是否符合预期,然后再CAS设置同步状态 // 返回值true代表获取成功,false代表获取失败 protected boolean tryAcquire(int arg) { throw new UnsupportedOperationException(); }// 钩子方法,独占式释放同步状态,需要子类实现, // 等待获取同步状态的线程将有机会获取同步状态 // 返回值true代表获取成功,false代表获取失败 protected boolean tryRelease(int arg) { throw new UnsupportedOperationException(); }// 钩子方法,共享式获取同步状态,需要子类实现, // 返回值负数代表获取失败、0代表获取成功但没有剩余资源、 // 正数代表获取成功,还有剩余资源 protected int tryAcquireShared(int arg) { throw new UnsupportedOperationException(); }// 钩子方法,共享式释放同步状态,需要子类实现 // 返回值负数代表获取失败、0代表获取成功但没有剩余资源、 // 正数代表获取成功,还有剩余资源 protected boolean tryReleaseShared(int arg) { throw new UnsupportedOperationException(); }// 模板方法,独占式获取同步状态,如果当前线程获取同步状态成功,则由该方法返回, // 否则会进入同步队列等待,此方法会调用子类重写的tryAcquire方法 public final void acquire(int arg) { if (!tryAcquire(arg) && acquireQueued(addWaiter(Node.EXCLUSIVE), arg)) selfInterrupt(); }// 模板方法,独占式的释放同步状态,该方法会在释放同步状态后, // 将同步队列中的第一个节点包含的线程唤醒 // 此方法会调用子类重写的tryRelease方法 public final boolean release(int arg) { if (tryRelease(arg)) { Node h = head; if (h != null && h.waitStatus != 0) unparkSuccessor(h); return true; } return false; }// 模板方法,共享式的获取同步状态,如果当前系统未获取到同步状态, // 将会进入同步队列等待,同一时刻可以有多个线程获取到同步状态 // 此方法会调用子类重写的tryAcquireShared方法 public final void acquireShared(int arg) { if (tryAcquireShared(arg) < 0) doAcquireShared(arg); }// 模板方法,共享式的释放同步状态 // 此方法会调用子类重写的tryReleaseShared方法 public final boolean releaseShared(int arg) { if (tryReleaseShared(arg)) { doReleaseShared(); return true; } return false; }// 用于将当前线程加入到等待队列的队尾,并返回当前线程所在的结点 private Node addWaiter(Node mode) { Node node = new Node(Thread.currentThread(), mode); // Try the fast path of enq; backup to full enq on failure Node pred = tail; // 尝试将Node放到队尾 if (pred != null) { node.prev = pred; if (compareAndSetTail(pred, node)) { pred.next = node; return node; } } enq(node); return node; }//初始化或自旋CAS直到入队成功 private Node enq(final Node node) { for (;;) { Node t = tail; if (t == null) { // Must initialize if (compareAndSetHead(new Node())) tail = head; } else { node.prev = t; if (compareAndSetTail(t, node)) { t.next = node; return t; } } } }
AQS实现CountDownLatch
CountDownLatch是一个同步工具类,用来协调多个线程之间的同步,CountDownLatch能够使一个线程在等待另外一些线程完成各自工作之后,再继续执行。使用一个计数器进行实现。计数器初始值为线程的数量。当每一个线程完成自己任务后,计数器的值就会减一。当计数器的值为0时,表示所有的线程都已经完成一些任务,然后在CountDownLatch上等待的线程就可以恢复执行接下来的任务。主要常用的方法countDown()方法以及await())方法。
基于AQS实现CountDownLatch
public class MyCountDownLatch {private Sync sync;public MyCountDownLatch(int count) { sync = new Sync(count); }public void countDown() { sync.releaseShared(1); }public void await() { sync.acquireShared(1); }class Sync extends AbstractQueuedSynchronizer { public Sync(int count) { setState(count); }@Override protected int tryAcquireShared(int arg) { // 只有当state变为0时,加锁成功 return getState() == 0 ? 1 : -1; }@Override protected boolean tryReleaseShared(int arg) { for (; ; ) { int c = getState(); if (c == 0) return false; int nextc = c - 1; // 用CAS操作,讲count减一 if (compareAndSetState(c, nextc)) { // 当state=0时,释放锁成功,返回true return nextc == 0; } } } } }// 测试 public class MyCountDownLatchTest { /* 每隔1s开启一个线程,共开启6个线程 若希望6个线程 同时 执行某一操作 可以用CountDownLatch实现 */ public static void test01() throws InterruptedException { MyCountDownLatch ctl = new MyCountDownLatch(6);for (int i = 0; i < 6; i++) { new Thread() { @Override public void run() { ctl.countDown(); ctl.await(); System.out.println("here I am..."); } }.start(); Thread.sleep(1000L); } }/* 开启6个线程,main线程希望6个线程都执行完某个操作后,才执行某个操作 可以用CountDownLatch来实现 */ public static void test02() throws InterruptedException { MyCountDownLatch ctl = new MyCountDownLatch(6);for (int i = 0; i < 6; i++) { new Thread() { @Override public void run() { System.out.println("after print..."); ctl.countDown(); } }.start(); Thread.sleep(1000L); }ctl.await(); System.out.println("main thread do something ..."); }public static void main(String args[]) throws InterruptedException { test01(); } }
AQS实现Semaphore
Semaphore是用来保护一个或者多个共享资源的访问,Semaphore内部维护了一个计数器,其值为可以访问的共享资源的个数。一个线程要访问共享资源,先获得信号量,如果信号量的计数器值大于1,意味着有共享资源可以访问,则使其计数器值减去1,再访问共享资源。Semaphore用来控制同时访问某个特定资源的操作数量,或者同时执行某个指定操作的数量。还可以用来实现某种资源池限制,或者对容器施加边界。常用方法为acquire()方法和release()方法。
基于AQS实现CountDownLatch
public class MySemaphore {private Sync sync;public MySemaphore(int permits) { sync = new Sync(permits); }//抢信号量、就是在加锁 public void acquire() { sync.acquireShared(1); }//释放信号量,就是解锁 public void release() { sync.releaseShared(1); }class Sync extends AbstractQueuedSynchronizer { private int permits;public Sync(int permits) { this.permits = permits; }@Override protected int tryAcquireShared(int arg) { int state = getState(); int nextState = state + arg; // 如果信号量没占满,加锁的个数没有达到permits if (nextState <= permits) { if (compareAndSetState(state, nextState)) { return 1; } } return -1; }@Override protected boolean tryReleaseShared(int arg) { int state = getState(); if (compareAndSetState(state, state - arg)) { return true; } else { return false; } } } }// 测试 public class MySemaphoreTest { static MySemaphore sp = new MySemaphore(6);public static void main(String args[]) { for (int i = 0; i < 1000; i++) { new Thread() { @Override public void run() { try { sp.acquire(); // 抢信号量、就是在加锁 Thread.sleep(2000L); } catch (InterruptedException e) { e.printStackTrace(); } queryDB("localhost:3006"); sp.release(); // 释放信号量,就是解锁 } }.start(); } }public static void queryDB(String url) { System.out.println("query " + url); } }