Golang 实现 Redis(3): 实现内存数据库

  • 2020 年 3 月 29 日
  • 筆記

本文是 golang 实现 redis 系列的第三篇, 主要介绍如何实现内存KV数据库。本文完整源代码在作者Github: HDT3213/godis

db.go 是内存数据库的主要源文件,db.Exec 方法会从协议解析器中获得命令参数并调用相应的处理函数进行处理。

目录:

Concurrent Hash Map

KV 内存数据库的核心是并发安全的哈希表,常见的设计有几种:

  • sync.map: golang 官方提供的并发哈希表, 性能优秀但结构复杂不便于扩展

  • juc.ConcurrentHashMap: java 的并发哈希表采用分段锁实现。在进行扩容时访问哈希表线程都将协助进行 rehash 操作,在 rehash 结束前所有的读写操作都会阻塞。因为缓存数据库中键值对数量巨大且对读写操作响应时间要求较高,使用juc的策略是不合适的。

  • memcached hashtable: 在后台线程进行 rehash 操作时,主线程会判断要访问的哈希槽是否已被 rehash 从而决定操作 old_hashtable 还是操作 primary_hashtable。
    这种策略使主线程和rehash线程之间的竞争限制在哈希槽内,最小化rehash操作对读写操作的影响,这是最理想的实现方式。但由于作者才疏学浅无法使用 golang 实现该策略故忍痛放弃(主要原因在于 golang 没有 volatile 关键字, 保证线程可见性的操作非常复杂),欢迎各位读者讨论。

本文采用在 sync.map 发布前 golang 社区广泛使用的分段锁策略。我们将key分散到固定数量的 shard 中避免 rehash 操作。shard 是有锁保护的 map, 当 shard 进行 rehash 时会阻塞shard内的读写,但不会对其他 shard 造成影响。

这种策略简单可靠易于实现,但由于需要两次 hash 性能略差。这个 dict 完整源码在Github 可以独立使用(虽然也没有什么用。。。)。

定义数据结构:

type ConcurrentDict struct {      table []*Shard      count int32  }    type Shard struct {      m     map[string]interface{}      mutex sync.RWMutex  }  

在构造时初始化 shard,这个操作相对比较耗时:

func computeCapacity(param int) (size int) {  	if param <= 16 {  		return 16  	}  	n := param - 1  	n |= n >> 1  	n |= n >> 2  	n |= n >> 4  	n |= n >> 8  	n |= n >> 16  	if n < 0 {  		return math.MaxInt32  	} else {  		return int(n + 1)  	}  }    func MakeConcurrent(shardCount int) *ConcurrentDict {      shardCount = computeCapacity(shardCount)      table := make([]*Shard, shardCount)      for i := 0; i < shardCount; i++ {          table[i] = &Shard{              m: make(map[string]interface{}),          }      }      d := &ConcurrentDict{          count: 0,          table: table,      }      return d  }  

哈希算法选择FNV算法:

const prime32 = uint32(16777619)    func fnv32(key string) uint32 {      hash := uint32(2166136261)      for i := 0; i < len(key); i++ {          hash *= prime32          hash ^= uint32(key[i])      }      return hash  }  

定位shard, 当n为2的整数幂时 h % n == (n – 1) & h

func (dict *ConcurrentDict) spread(hashCode uint32) uint32 {  	if dict == nil {  		panic("dict is nil")  	}  	tableSize := uint32(len(dict.table))  	return (tableSize - 1) & uint32(hashCode)  }    func (dict *ConcurrentDict) getShard(index uint32) *Shard {  	if dict == nil {  		panic("dict is nil")  	}  	return dict.table[index]  }  

Get 和 Put 方法实现:

func (dict *ConcurrentDict) Get(key string) (val interface{}, exists bool) {  	if dict == nil {  		panic("dict is nil")  	}  	hashCode := fnv32(key)  	index := dict.spread(hashCode)  	shard := dict.getShard(index)  	shard.mutex.RLock()  	defer shard.mutex.RUnlock()  	val, exists = shard.m[key]  	return  }    func (dict *ConcurrentDict) Len() int {  	if dict == nil {  		panic("dict is nil")  	}  	return int(atomic.LoadInt32(&dict.count))  }    // return the number of new inserted key-value  func (dict *ConcurrentDict) Put(key string, val interface{}) (result int) {  	if dict == nil {  		panic("dict is nil")  	}  	hashCode := fnv32(key)  	index := dict.spread(hashCode)  	shard := dict.getShard(index)  	shard.mutex.Lock()  	defer shard.mutex.Unlock()    	if _, ok := shard.m[key]; ok {  		shard.m[key] = val  		return 0  	} else {  		shard.m[key] = val  		dict.addCount()  		return 1  	}  }  

LockMap

上一节实现的ConcurrentMap 可以保证对单个 key 操作的并发安全性,但是仍然无法满足需求:

  1. MSETNX 命令当且仅当所有给定键都不存在时所有给定键设置值, 因此我们需要锁定所有给定的键直到完成所有键的检查和设置
  2. LPOP 命令移除列表中最后一个元素后需要移除该键值对,因此我们锁定该键直到移除元素并移除空列表

因此我们需要实现 db.Locker 用于锁定一个或一组 key 并在我们需要的时候释放锁。

实现 db.Locker 最直接的想法是使用一个 map[string]*sync.RWMutex, 加锁过程分为两步: 初始化对应的锁 -> 加锁, 解锁过程也分为两步: 解锁 -> 释放对应的锁。那么存在一个无法解决的并发问题:

时间 协程A 协程B
1 locker["a"].Unlock()
2 locker["a"] = &sync.RWMutex{}
3 delete(locker["a"])
4 locker["a"].Lock()

由于 t3 时协程B释放了锁,t4 时协程A试图加锁会失败。

若我们在解锁时不释放锁就可以避免该异常的发生,但是每个曾经使用过的锁都无法释放从而造成严重的内存泄露。

我们注意到哈希表的长度远少于可能的键的数量,反过来说多个键可以共用一个哈希槽。若我们不为单个键加锁而是为它所在的哈希槽加锁,因为哈希槽的数量非常少即使不释放锁也不会占用太多内存。

作者根据这种思想实现了 LockerMap 来解决并发控制问题。

type Locks struct {      table []*sync.RWMutex  }    func Make(tableSize int) *Locks {      table := make([]*sync.RWMutex, tableSize)      for i := 0; i < tableSize; i++ {          table[i] = &sync.RWMutex{}      }      return &Locks{          table: table,      }  }    func (locks *Locks)Lock(key string) {      index := locks.spread(fnv32(key))      mu := locks.table[index]      mu.Lock()  }    func (locks *Locks)UnLock(key string) {      index := locks.spread(fnv32(key))      mu := locks.table[index]      mu.Unlock()  }  

哈希算法已经在Dict一节介绍过不再赘述。

在锁定多个key时需要注意,若协程A持有键a的锁试图获得键b的锁,此时协程B持有键b的锁试图获得键a的锁则会形成死锁。

解决方法是所有协程都按照相同顺序加锁,若两个协程都想获得键a和键b的锁,那么必须先获取键a的锁后获取键b的锁,这样就可以避免循环等待。

func (locks *Locks)Locks(keys ...string) {      keySlice := make(sort.StringSlice, len(keys))      copy(keySlice, keys)      sort.Sort(keySlice)      for _, key := range keySlice {          locks.Lock(key)      }  }    func (locks *Locks)RLocks(keys ...string) {      keySlice := make(sort.StringSlice, len(keys))      copy(keySlice, keys)      sort.Sort(keySlice)      for _, key := range keySlice {          locks.RLock(key)      }  }  

TTL

Time To Live (TTL) 的实现方式非常简单,其核心是 string -> time 哈希表。

当访问某个 key 时会检查是否过期,并删除过期key:

func (db *DB) Get(key string) (*DataEntity, bool) {  	db.stopWorld.RLock()  	defer db.stopWorld.RUnlock()    	raw, ok := db.Data.Get(key)  	if !ok {  		return nil, false  	}  	if db.IsExpired(key) {  		return nil, false  	}  	entity, _ := raw.(*DataEntity)  	return entity, true  }    func (db *DB) IsExpired(key string) bool {  	rawExpireTime, ok := db.TTLMap.Get(key)  	if !ok {  		return false  	}  	expireTime, _ := rawExpireTime.(time.Time)  	expired := time.Now().After(expireTime)  	if expired {  		db.Remove(key)  	}  	return expired  }  

同时会定时的检查过期key并删除:

func (db *DB) CleanExpired() {  	now := time.Now()  	toRemove := &List.LinkedList{}  	db.TTLMap.ForEach(func(key string, val interface{}) bool {  		expireTime, _ := val.(time.Time)  		if now.After(expireTime) {  			// expired  			db.Data.Remove(key)  			toRemove.Add(key)  		}  		return true  	})  	toRemove.ForEach(func(i int, val interface{}) bool {  		key, _ := val.(string)  		db.TTLMap.Remove(key)  		return true  	})  }    func (db *DB) TimerTask() {  	ticker := time.NewTicker(db.interval)  	go func() {  		for range ticker.C {  			db.CleanExpired()  		}  	}()  }