AQS解读及其实践

  • 2019 年 10 月 7 日
  • 筆記

AQS概述

AQS全称AbstractQueuedSynchronizer,即抽象队列同步器。AQS是用来构建锁或者其他同步组件的基础框架,它使用一个整型的volatile变量state来维护同步状态,通过内置的FIFO队列来完成资源获取线程的排队工作。AQS为一系列同步器依赖于一个单独的原子变量state的同步器提供了一个非常有用的基础。

AQS的设计是基于模板方法模式设计的,子类通过继承AQS并实现它的抽象模板方法来管理同步状态,而这些模板方法内部就是真正管理同步状态的地方(主要有tryAcquire、tryRelease、tryAcquireShared、tryReleaseShared等)。

AQS既可以支持独占锁地,也支持共享锁,这样就可以方便实现不同类型的同步组件如ReentrantLock、ReentrantReadWriteLock和CountDownLatch等。

AQS类使用单个int(32位)来保存同步状态,并暴露出getState、setState以及compareAndSet操作来读取和更新这个同步状态。其中属性state被声明为volatile,并且通过使用CAS指令来实现compareAndSetState,使得当且仅当同步状态拥有一个一致的期望值的时候,才会被原子地设置成新值,这样就达到了同步状态的原子性管理,确保了同步状态的原子性、可见性和有序性。

补充:ReentrantReadWriteLock利用一个32位的int值保存了两个count,前16位存readCount,后16位存writeCount。

AQS核心源码解读

AQS源码中的主要字段

// 同步队列的head节点, 延迟初始化,除了初始化,只能通过setHead方法修改  // 如果head存在,waitStatus一定是CANCELLED  private transient volatile Node head;  // 同步队列的tail节点,延迟初始化,只能通过enq方法修改  private transient volatile Node tail;  // 同步状态  private volatile int state;  // 支持CAS  private static final Unsafe unsafe = Unsafe.getUnsafe();  private static final long stateOffset;  private static final long headOffset;  private static final long tailOffset;  private static final long waitStatusOffset;  private static final long nextOffset;

AQS源码中的主要方法

    protected final int getState() {  return state;  }protected final void setState(int newState) {  state = newState;  }protected final boolean compareAndSetState(int expect, int update) {  return unsafe.compareAndSwapInt(this, stateOffset, expect, update);  }// 钩子方法,独占式获取同步状态, 需要子类实现,实现此方法需要查询当前同步状态并  // 判断同步状态是否符合预期,然后再CAS设置同步状态  // 返回值true代表获取成功,false代表获取失败  protected boolean tryAcquire(int arg) {  throw new UnsupportedOperationException();  }// 钩子方法,独占式释放同步状态,需要子类实现,  // 等待获取同步状态的线程将有机会获取同步状态  // 返回值true代表获取成功,false代表获取失败  protected boolean tryRelease(int arg) {  throw new UnsupportedOperationException();  }// 钩子方法,共享式获取同步状态,需要子类实现,  // 返回值负数代表获取失败、0代表获取成功但没有剩余资源、  // 正数代表获取成功,还有剩余资源  protected int tryAcquireShared(int arg) {  throw new UnsupportedOperationException();  }// 钩子方法,共享式释放同步状态,需要子类实现  // 返回值负数代表获取失败、0代表获取成功但没有剩余资源、  // 正数代表获取成功,还有剩余资源  protected boolean tryReleaseShared(int arg) {  throw new UnsupportedOperationException();  }// 模板方法,独占式获取同步状态,如果当前线程获取同步状态成功,则由该方法返回,  // 否则会进入同步队列等待,此方法会调用子类重写的tryAcquire方法  public final void acquire(int arg) {  if (!tryAcquire(arg) &&  acquireQueued(addWaiter(Node.EXCLUSIVE), arg))  selfInterrupt();  }// 模板方法,独占式的释放同步状态,该方法会在释放同步状态后,  // 将同步队列中的第一个节点包含的线程唤醒  // 此方法会调用子类重写的tryRelease方法  public final boolean release(int arg) {  if (tryRelease(arg)) {  Node h = head;  if (h != null && h.waitStatus != 0)  unparkSuccessor(h);  return true;  }  return false;  }// 模板方法,共享式的获取同步状态,如果当前系统未获取到同步状态,  // 将会进入同步队列等待,同一时刻可以有多个线程获取到同步状态  // 此方法会调用子类重写的tryAcquireShared方法  public final void acquireShared(int arg) {  if (tryAcquireShared(arg) < 0)  doAcquireShared(arg);  }// 模板方法,共享式的释放同步状态  // 此方法会调用子类重写的tryReleaseShared方法  public final boolean releaseShared(int arg) {  if (tryReleaseShared(arg)) {  doReleaseShared();  return true;  }  return false;  }// 用于将当前线程加入到等待队列的队尾,并返回当前线程所在的结点  private Node addWaiter(Node mode) {  Node node = new Node(Thread.currentThread(), mode);  // Try the fast path of enq; backup to full enq on failure  Node pred = tail;    // 尝试将Node放到队尾  if (pred != null) {  node.prev = pred;  if (compareAndSetTail(pred, node)) {  pred.next = node;  return node;  }  }  enq(node);  return node;  }//初始化或自旋CAS直到入队成功  private Node enq(final Node node) {  for (;;) {  Node t = tail;  if (t == null) { // Must initialize  if (compareAndSetHead(new Node()))  tail = head;  } else {  node.prev = t;  if (compareAndSetTail(t, node)) {  t.next = node;  return t;  }  }  }  }

AQS实现CountDownLatch

CountDownLatch是一个同步工具类,用来协调多个线程之间的同步,CountDownLatch能够使一个线程在等待另外一些线程完成各自工作之后,再继续执行。使用一个计数器进行实现。计数器初始值为线程的数量。当每一个线程完成自己任务后,计数器的值就会减一。当计数器的值为0时,表示所有的线程都已经完成一些任务,然后在CountDownLatch上等待的线程就可以恢复执行接下来的任务。主要常用的方法countDown()方法以及await())方法。

基于AQS实现CountDownLatch

public class MyCountDownLatch {private Sync sync;public MyCountDownLatch(int count) {  sync = new Sync(count);  }public void countDown() {  sync.releaseShared(1);  }public void await() {  sync.acquireShared(1);  }class Sync extends AbstractQueuedSynchronizer {  public Sync(int count) {  setState(count);  }@Override  protected int tryAcquireShared(int arg) {  // 只有当state变为0时,加锁成功  return getState() == 0 ? 1 : -1;  }@Override  protected boolean tryReleaseShared(int arg) {  for (; ; ) {  int c = getState();  if (c == 0) return false;  int nextc = c - 1;  // 用CAS操作,讲count减一  if (compareAndSetState(c, nextc)) {  // 当state=0时,释放锁成功,返回true  return nextc == 0;  }  }  }  }  }// 测试  public class MyCountDownLatchTest {  /*  每隔1s开启一个线程,共开启6个线程  若希望6个线程 同时 执行某一操作  可以用CountDownLatch实现  */  public static void test01() throws InterruptedException {  MyCountDownLatch ctl = new MyCountDownLatch(6);for (int i = 0; i < 6; i++) {  new Thread() {  @Override  public void run() {  ctl.countDown();  ctl.await();  System.out.println("here I am...");  }  }.start();  Thread.sleep(1000L);  }  }/*  开启6个线程,main线程希望6个线程都执行完某个操作后,才执行某个操作  可以用CountDownLatch来实现  */  public static void test02() throws InterruptedException {  MyCountDownLatch ctl = new MyCountDownLatch(6);for (int i = 0; i < 6; i++) {  new Thread() {  @Override  public void run() {  System.out.println("after print...");  ctl.countDown();  }  }.start();  Thread.sleep(1000L);  }ctl.await();  System.out.println("main thread do something ...");  }public static void main(String args[]) throws InterruptedException {  test01();  }  }

AQS实现Semaphore

Semaphore是用来保护一个或者多个共享资源的访问,Semaphore内部维护了一个计数器,其值为可以访问的共享资源的个数。一个线程要访问共享资源,先获得信号量,如果信号量的计数器值大于1,意味着有共享资源可以访问,则使其计数器值减去1,再访问共享资源。Semaphore用来控制同时访问某个特定资源的操作数量,或者同时执行某个指定操作的数量。还可以用来实现某种资源池限制,或者对容器施加边界。常用方法为acquire()方法和release()方法。

基于AQS实现CountDownLatch

public class MySemaphore {private Sync sync;public MySemaphore(int permits) {  sync = new Sync(permits);  }//抢信号量、就是在加锁  public void acquire() {  sync.acquireShared(1);  }//释放信号量,就是解锁  public void release() {  sync.releaseShared(1);  }class Sync extends AbstractQueuedSynchronizer {  private int permits;public Sync(int permits) {  this.permits = permits;  }@Override  protected int tryAcquireShared(int arg) {  int state = getState();  int nextState = state + arg;  // 如果信号量没占满,加锁的个数没有达到permits  if (nextState <= permits) {  if (compareAndSetState(state, nextState)) {  return 1;  }  }  return -1;  }@Override  protected boolean tryReleaseShared(int arg) {  int state = getState();  if (compareAndSetState(state, state - arg)) {  return true;  } else {  return false;  }  }  }  }// 测试  public class MySemaphoreTest {  static MySemaphore sp = new MySemaphore(6);public static void main(String args[]) {  for (int i = 0; i < 1000; i++) {  new Thread() {  @Override  public void run() {  try {  sp.acquire(); // 抢信号量、就是在加锁  Thread.sleep(2000L);  } catch (InterruptedException e) {  e.printStackTrace();  }  queryDB("localhost:3006");  sp.release(); // 释放信号量,就是解锁  }  }.start();  }  }public static void queryDB(String url) {  System.out.println("query " + url);  }  }