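At the end of the previous article we noted that Hashtable is rarely used today. So how do we keep a map consistent when thread safety is required? That is the focus of this article: ConcurrentHashMap (JDK 1.7). ConcurrentHashMap is considerably more complex than HashMap and Hashtable; internally it uses lock striping (segmented locking) to make concurrent reads and writes more efficient. Let's start with some test code:
Listing 1: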
- import java.util.HashMap;
- import java.util.Hashtable;
- import java.util.Map;
- import java.util.concurrent.ConcurrentHashMap;
- public class CurrentHashMapTest {
- private static ConcurrentHashMap<String, String> concurrentHashMap = new ConcurrentHashMap<>();
- private static Hashtable<String, String> hashtable = new Hashtable<>();
- private static HashMap<String, String> hashMap = new HashMap<>();
- public static void main(String[] args) {
- testConcurrentHashMapThreadSafe();
- System.out.println(concurrentHashMap.size() + "last:" + concurrentHashMap.get("concurrentHashMap9999"));
- testHashtableThreadSafe();
- System.out.println(hashtable.size() + "last:" + hashtable.get("hashtable9999"));
- testHashMapThreadSafe();
- System.out.println(hashMap.size() + "last:" + hashMap.get("hashmap9999"));
- System.out.println("test end");
- }
- public static void testConcurrentHashMapThreadSafe() {
- long startTime = System.currentTimeMillis();
- for (int i = 0; i < 100000; i++) {
- new ConcurrentThread(i, "concurrentHashMap", concurrentHashMap).start();
- }
- long endTime = System.currentTimeMillis();
- System.out.println("ConcurrentHashMap take time:" + (endTime - startTime));
- }
- public static void testHashtableThreadSafe() {
- long startTime = System.currentTimeMillis();
- for (int i = 0; i < 100000; i++) {
- new ConcurrentHashTableThread(i, "hashtable", hashtable).start();
- }
- long endTime = System.currentTimeMillis();
- System.out.println("Hashtable take time:" + (endTime - startTime));
- }
- public static void testHashMapThreadSafe() {
- System.out.println("enter test HashMap");
- long startTime = System.currentTimeMillis();
- for (int i = 0; i < 100000; i++) {
- new ConcurrentHashMapThread(i, "hashmap", hashMap).start();
- }
- long endTime = System.currentTimeMillis();
- System.out.println(" HashMap take time:" + (endTime - startTime));
- }
- }
- class ConcurrentThread extends Thread {
- public int i;
- public String name;
- private ConcurrentHashMap<String, String> map;
- public ConcurrentThread(int i, String name, ConcurrentHashMap<String, String> map) {
- this.i = i;
- this.name = name;
- this.map = map;
- }
- @Override public void run() {
- super.run();
- map.put(name + i, i + "");
- }
- }
- class ConcurrentHashTableThread extends Thread {
- public int i;
- public String name;
- private Hashtable<String, String> map;
- public ConcurrentHashTableThread(int i, String name, Hashtable<String, String> map) {
- this.i = i;
- this.name = name;
- this.map = map;
- }
- @Override public void run() {
- super.run();
- map.put(name + i, i + "");
- }
- }
- class ConcurrentHashMapThread extends Thread {
- public int i;
- public String name;
- private HashMap<String, String> map;
- public ConcurrentHashMapThread(int i, String name, HashMap<String, String> map) {
- this.i = i;
- this.name = name;
- this.map = map;
- }
- @Override public void run() {
- super.run();
- map.put(name + i, i + "");
- }
- }
Output of the code above (environment: Ubuntu 14.04 + IDEA + JDK 1.7):
ConcurrentHashMap take time:3522
100000last:9999
Hashtable take time:3674
100000last:9999
enter test HashMap
HashMap take time:1105168
99945last:9999
test end
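In this run ConcurrentHashMap finished slightly ahead of Hashtable (3522 ms vs. 3674 ms), while HashMap proved unsafe: it lost entries (99945 instead of 100000) and took far longer. Because the test never joins its threads, the timings mostly cover creating and starting 100,000 threads, so treat them as a rough indication rather than a precise benchmark.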
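For a comparison that waits for every writer to finish before stopping the clock, something like the following sketch could be used. The class `JoinedPutBenchmark`, the helper `timePuts`, and the thread and key counts are all assumptions made for illustration; they are not part of the original test, and the numbers it prints will vary by machine.

```java
import java.util.Hashtable;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical benchmark helper, not the original test code.
class JoinedPutBenchmark {
    // Runs `threads` workers, each inserting `perThread` distinct keys, and
    // returns the elapsed milliseconds once every worker has finished.
    static long timePuts(final Map<String, String> map, int threads, final int perThread) {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        final CountDownLatch done = new CountDownLatch(threads);
        long start = System.nanoTime();
        for (int t = 0; t < threads; t++) {
            final int id = t;
            pool.execute(new Runnable() {
                @Override public void run() {
                    for (int i = 0; i < perThread; i++) {
                        map.put("key" + id + "-" + i, String.valueOf(i));
                    }
                    done.countDown();
                }
            });
        }
        try {
            done.await(); // wait for all writers before reading the clock
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
        return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
    }

    public static void main(String[] args) {
        Map<String, String> chm = new ConcurrentHashMap<String, String>();
        Map<String, String> ht = new Hashtable<String, String>();
        System.out.println("ConcurrentHashMap ms: " + timePuts(chm, 8, 12500) + ", size: " + chm.size());
        System.out.println("Hashtable ms: " + timePuts(ht, 8, 12500) + ", size: " + ht.size());
    }
}
```

A fixed pool keeps thread-creation cost out of the measurement, which the one-thread-per-key approach in Listing 1 cannot do.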
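First, the internal structure of ConcurrentHashMap (the figure from the original post is not reproduced here): a segments array in which each Segment holds its own HashEntry table, and each slot of that table heads a singly linked list.
Following the usual approach, let's look at the ConcurrentHashMap constructors, shown in Listing 2: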
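- static final int DEFAULT_INITIAL_CAPACITY = 16; // default initial capacity, same as HashMap
- static final float DEFAULT_LOAD_FACTOR = 0.75f; // load factor
- static final int DEFAULT_CONCURRENCY_LEVEL = 16; // concurrency level
- static final int MAXIMUM_CAPACITY = 1 << 30; // maximum capacity; DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR and MAXIMUM_CAPACITY all have the same values as the corresponding HashMap fields
- static final int MIN_SEGMENT_TABLE_CAPACITY = 2; // minimum capacity of each segment's table (not of the segments array); a power of two, at least 2 so that a lazily constructed segment does not have to resize on its next use
- static final int MAX_SEGMENTS = 1 << 16; // maximum length of the segments array
- static final int RETRIES_BEFORE_LOCK = 2; // number of unsynchronized retries in size() and containsValue() before locking all segments
- final int segmentMask; // low-order bit mask used to index into the segments array
- final int segmentShift; // shift applied to the hash before masking, so that the high bits select the segment
- final Segment<K,V>[] segments; // the segments array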
- transient Set<K> keySet;
- transient Set<Map.Entry<K,V>> entrySet;
- transient Collection<V> values;
- public ConcurrentHashMap(int initialCapacity,
- float loadFactor, int concurrencyLevel) {
- if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
- throw new IllegalArgumentException();
- if (concurrencyLevel > MAX_SEGMENTS)
- concurrencyLevel = MAX_SEGMENTS;
- // Find power-of-two sizes best matching arguments
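- int sshift = 0; // number of left shifts performed
- int ssize = 1; // computed length of the segments array
- while (ssize < concurrencyLevel) { // when reading code like this, plug in sample inputs to see what it does
- ++sshift;
- ssize <<= 1; // ssize == 2^sshift
- }
- this.segmentShift = 32 - sshift; // (hash >>> segmentShift) keeps the top sshift bits of the hash
- this.segmentMask = ssize - 1; // low-order mask: ssize is a power of two, so the low sshift bits of segmentMask are all 1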
- if (initialCapacity > MAXIMUM_CAPACITY)
- initialCapacity = MAXIMUM_CAPACITY;
- int c = initialCapacity / ssize;
- if (c * ssize < initialCapacity)
- ++c;
- int cap = MIN_SEGMENT_TABLE_CAPACITY;
- while (cap < c) // cap stays a power of two; after the loop it is the capacity of each segment's table
- cap <<= 1;
- // create segments and segments[0]
- Segment<K,V> s0 =
- new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
- (HashEntry<K,V>[])new HashEntry[cap]);
- Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize]; // create the segments array
- UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of s0 into segments[0] (SBASE is the offset of the first element)
- this.segments = ss;
- }
- public ConcurrentHashMap(int initialCapacity, float loadFactor) {
- this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL);
- }
- public ConcurrentHashMap(int initialCapacity) {
- this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
- }
- public ConcurrentHashMap() {
- this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
- }
- public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
- this(Math.max((int) (m.size() / DEFAULT_LOAD_FACTOR) + 1,
- DEFAULT_INITIAL_CAPACITY),
- DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
- putAll(m);
- }
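The while loop in the three-argument constructor of Listing 2 (from int sshift = 0 through the assignment of segmentMask) computes segmentShift and segmentMask. Two example calculations: with concurrencyLevel = 16 the loop ends with sshift = 4 and ssize = 16, so segmentShift = 28 and segmentMask = 15; with concurrencyLevel = 33 it ends with sshift = 6 and ssize = 64, so segmentShift = 26 and segmentMask = 63.
As these two cases show, segmentShift and segmentMask are determined by concurrencyLevel. The meaning of each variable is explained in the code comments, so it is not repeated here.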
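To experiment with other inputs, the constructor's arithmetic can be lifted into a small standalone program. This is only a sketch: the class `SegmentSizing` and the method `compute` are hypothetical names, and the clamping against MAX_SEGMENTS and MAXIMUM_CAPACITY is left out for brevity.

```java
import java.util.Arrays;

// Hypothetical helper that repeats the sizing arithmetic from Listing 2.
class SegmentSizing {
    static final int MIN_SEGMENT_TABLE_CAPACITY = 2;

    // Returns {sshift, ssize, segmentShift, segmentMask, cap}.
    static int[] compute(int initialCapacity, int concurrencyLevel) {
        int sshift = 0;
        int ssize = 1;
        while (ssize < concurrencyLevel) { // round up to a power of two
            ++sshift;
            ssize <<= 1;
        }
        int segmentShift = 32 - sshift;
        int segmentMask = ssize - 1;
        int c = initialCapacity / ssize;   // entries per segment, rounded up
        if (c * ssize < initialCapacity) ++c;
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        while (cap < c) cap <<= 1;         // per-segment table capacity
        return new int[] {sshift, ssize, segmentShift, segmentMask, cap};
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(compute(16, 16)));  // [4, 16, 28, 15, 2]
        System.out.println(Arrays.toString(compute(100, 33))); // [6, 64, 26, 63, 2]
    }
}
```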
We create a ConcurrentHashMap in order to use it, which brings us to the put method, shown in Listing 3.
Listing 3
- public V put(K key, V value) {
- Segment<K,V> s;
- if (value == null) throw new NullPointerException(); // ConcurrentHashMap accepts neither null keys nor null values (a null key fails inside hash(key))
- int hash = hash(key); // compute the hash
- int j = (hash >>> segmentShift) & segmentMask; // segment index: keep the high bits of the hash, then mask with segmentMask
- if ((s = (Segment<K,V>) UNSAFE.getObject // nonvolatile; recheck
- (segments, (j << SSHIFT) + SBASE)) == null) // read the element at offset (j << SSHIFT) + SBASE; create the segment if it is null
- s = ensureSegment(j);
- return s.put(key, hash, value, false); // the actual insertion is done by the Segment
- }
- private int hash(Object k) { // re-mixes the key's raw hashCode to reduce collisions
- int h = hashSeed;
- if ((0 != h) && (k instanceof String)) {
- return sun.misc.Hashing.stringHash32((String) k);
- }
- h ^= k.hashCode();
- // Spread bits to regularize both segment and index locations,
- // using variant of single-word Wang/Jenkins hash.
- h += (h << 15) ^ 0xffffcd7d;
- h ^= (h >>> 10);
- h += (h << 3);
- h ^= (h >>> 6);
- h += (h << 2) + (h << 14);
- return h ^ (h >>> 16);
- }
- private Segment<K,V> ensureSegment(int k) {
- final Segment<K,V>[] ss = this.segments;
- long u = (k << SSHIFT) + SBASE; // raw offset of element k within the array
- Segment<K,V> seg;
- if ((seg = (Segment<K,V>) UNSAFE.getObjectVolatile(ss, u)) == null) { // nothing at that offset yet: use ss[0] as a prototype
- Segment<K,V> proto = ss[0]; // use segment 0 as prototype
- int cap = proto.table.length; // copy the capacity
- float lf = proto.loadFactor; // copy the load factor
- int threshold = (int)(cap * lf); // resize threshold
- HashEntry<K,V>[] tab = (HashEntry<K,V>[]) new HashEntry[cap];
- if ((seg = (Segment<K,V>) UNSAFE.getObjectVolatile(ss, u)) == null) { // recheck for null
- Segment<K,V> s = new Segment<K,V>(lf, threshold, tab); // create the Segment
- while ((seg = (Segment<K,V>) UNSAFE.getObjectVolatile(ss, u)) == null) { // loop while the slot at offset u is still null
- if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s)) // leave the loop once the CAS succeeds
- break;
- }
- }
- }
- return seg; // either the Segment created here or the one another thread installed at u
- }
- // Unsafe mechanics
- private static final sun.misc.Unsafe UNSAFE;
- private static final long SBASE;
- private static final int SSHIFT; // log2 of the size of one Segment[] element
- private static final long TBASE;
- private static final int TSHIFT; // log2 of the size of one HashEntry[] element
- private static final long HASHSEED_OFFSET;
- private static final long SEGSHIFT_OFFSET;
- private static final long SEGMASK_OFFSET;
- private static final long SEGMENTS_OFFSET;
- static {
- int ss,
- ts;
- try {
- UNSAFE = sun.misc.Unsafe.getUnsafe();
- Class tc = HashEntry[].class;
- Class sc = Segment[].class;
- TBASE = UNSAFE.arrayBaseOffset(tc); // offset of the first element in a HashEntry[] (i.e. the size of the array header)
- SBASE = UNSAFE.arrayBaseOffset(sc); // offset of the first element in a Segment[]
- ts = UNSAFE.arrayIndexScale(tc); // size in bytes of one HashEntry[] element (a reference, not a HashEntry object)
- ss = UNSAFE.arrayIndexScale(sc); // size in bytes of one Segment[] element (a reference)
- HASHSEED_OFFSET = UNSAFE.objectFieldOffset(ConcurrentHashMap.class.getDeclaredField("hashSeed")); // field offset of hashSeed within a ConcurrentHashMap object
- SEGSHIFT_OFFSET = UNSAFE.objectFieldOffset(ConcurrentHashMap.class.getDeclaredField("segmentShift")); // field offset of segmentShift
- SEGMASK_OFFSET = UNSAFE.objectFieldOffset(ConcurrentHashMap.class.getDeclaredField("segmentMask")); // field offset of segmentMask
- SEGMENTS_OFFSET = UNSAFE.objectFieldOffset(ConcurrentHashMap.class.getDeclaredField("segments")); // field offset of segments
- } catch(Exception e) {
- throw new Error(e);
- }
- if ((ss & (ss - 1)) != 0 || (ts & (ts - 1)) != 0) // ss and ts must both be powers of two
- throw new Error("data type scale not a power of two");
- SSHIFT = 31 - Integer.numberOfLeadingZeros(ss); // numberOfLeadingZeros counts the zero bits above the highest 1 bit of an int, so for a power of two this is log2(ss); index << SSHIFT is then the byte offset of an element
- TSHIFT = 31 - Integer.numberOfLeadingZeros(ts);
- }
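The index and offset arithmetic in Listing 3 can be tried out in isolation. The sketch below is illustrative only: `SegmentIndexDemo` and its methods are hypothetical names, and the base offset of 16 and element scale of 4 used in main are assumed values, since the real ones come from Unsafe and depend on the JVM.

```java
// Hypothetical demo of the index/offset arithmetic; no Unsafe involved.
class SegmentIndexDemo {
    // Same expression as in put(): the high bits of the hash pick the segment.
    static int segmentIndex(int hash, int segmentShift, int segmentMask) {
        return (hash >>> segmentShift) & segmentMask;
    }

    // log2 of a power-of-two element size, as computed for SSHIFT and TSHIFT.
    static int shiftFor(int scale) {
        if (scale <= 0 || (scale & (scale - 1)) != 0)
            throw new IllegalArgumentException("scale not a power of two");
        return 31 - Integer.numberOfLeadingZeros(scale);
    }

    // Byte offset of element `index`: (index << shift) + base.
    static long elementOffset(long base, int index, int shift) {
        return ((long) index << shift) + base;
    }

    public static void main(String[] args) {
        // 16 segments: segmentShift = 28, segmentMask = 15
        System.out.println(segmentIndex(0x12345678, 28, 15)); // 1
        System.out.println(segmentIndex(0xF0000000, 28, 15)); // 15
        int shift = shiftFor(4);                               // assumed 4-byte references
        System.out.println(shift);                             // 2
        System.out.println(elementOffset(16L, 3, shift));      // 28, with an assumed base of 16
    }
}
```

Shifting by log2 of the element size replaces a multiplication, which is why the static initializer insists that both scales are powers of two.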
Listing 3 shows that ConcurrentHashMap delegates put to a Segment. Let's dig further; see Listing 4.
Listing 4
- static final class Segment<K,V> extends ReentrantLock implements Serializable { // extends ReentrantLock, a reentrant lock
- private static final long serialVersionUID = 2249069246763182397L;
- static final int MAX_SCAN_RETRIES = Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;
- transient volatile HashEntry<K,V>[] table; // the hash table of this segment
- transient int count; // number of elements in this segment
- transient int modCount; // number of modifications
- transient int threshold; // resize threshold
- final float loadFactor; // load factor
- Segment(float lf, int threshold, HashEntry < K, V > [] tab) {
- this.loadFactor = lf;
- this.threshold = threshold;
- this.table = tab;
- }
- final V put(K key, int hash, V value, boolean onlyIfAbsent) { // put operation
- HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value); // the segment lock is held after this line; node is non-null only if scanAndLockForPut found no matching key and pre-created a node
- V oldValue;
- try {
- HashEntry<K,V>[] tab = table;
- int index = (tab.length - 1) & hash; // index into the table array
- HashEntry<K,V> first = entryAt(tab, index);
- for (HashEntry<K,V> e = first;;) {
- if (e != null) { // walk the list; if the key is not found, e becomes null and control moves to the else branch
- K k;
- if ((k = e.key) == key || (e.hash == hash && key.equals(k))) {
- oldValue = e.value;
- if (!onlyIfAbsent) {
- e.value = value;
- ++modCount;
- }
- break;
- }
- e = e.next;
- } else {
- if (node != null) node.setNext(first); // insert at the head of the list
- else node = new HashEntry<K,V>(hash, key, value, first); // insert at the head of the list
- int c = count + 1;
- if (c > threshold && tab.length < MAXIMUM_CAPACITY) // resize
- rehash(node);
- else setEntryAt(tab, index, node);
- ++modCount;
- count = c;
- oldValue = null;
- break;
- }
- }
- } finally {
- unlock();
- }
- return oldValue;
- }
- @SuppressWarnings("unchecked") private void rehash(HashEntry<K,V> node) { // doubles the table and relocates the entries; called from put with the segment lock held
- HashEntry<K,V>[] oldTable = table;
- int oldCapacity = oldTable.length;
- int newCapacity = oldCapacity << 1; // the capacity doubles
- threshold = (int)(newCapacity * loadFactor); // new threshold
- HashEntry<K,V>[] newTable = (HashEntry<K,V>[]) new HashEntry[newCapacity];
- int sizeMask = newCapacity - 1;
- for (int i = 0; i < oldCapacity; i++) {
- HashEntry<K,V> e = oldTable[i]; // visit each bucket of the old table, then its linked list
- if (e != null) {
- HashEntry<K,V> next = e.next;
- int idx = e.hash & sizeMask;
- if (next == null) // Single node on list
- newTable[idx] = e;
- else { // Reuse consecutive sequence at same slot
- HashEntry<K,V> lastRun = e;
- int lastIdx = idx;
- for (HashEntry<K,V> last = next; last != null; last = last.next) { // find lastRun: the start of the trailing run of nodes that all map to the same new index
- int k = last.hash & sizeMask;
- if (k != lastIdx) {
- lastIdx = k;
- lastRun = last;
- }
- }
- newTable[lastIdx] = lastRun; // the trailing run is moved as-is, without copying
- // Clone remaining nodes
- for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
- V v = p.value;
- int h = p.hash;
- int k = h & sizeMask;
- HashEntry<K,V> n = newTable[k];
- newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
- }
- }
- }
- }
- int nodeIndex = node.hash & sizeMask; // add the new node
- node.setNext(newTable[nodeIndex]);
- newTable[nodeIndex] = node;
- table = newTable;
- }
- private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
- HashEntry<K,V> first = entryForHash(this, hash); // head of the bucket for this hash
- HashEntry<K,V> e = first;
- HashEntry<K,V> node = null;
- int retries = -1; // negative while still locating the node; set to 0 once the key is found or the end of the list is reached
- while (!tryLock()) { // spin on tryLock; the loop ends as soon as the current thread acquires the lock
- HashEntry<K,V> f; // to recheck first below
- if (retries < 0) { // still searching for the key
- if (e == null) {
- if (node == null) // speculatively create node
- node = new HashEntry<K,V>(hash, key, value, null);
- retries = 0;
- } else if (key.equals(e.key)) retries = 0;
- else e = e.next;
- } else if (++retries > MAX_SCAN_RETRIES) {
- lock();
- break;
- } else if ((retries & 1) == 0 && (f = entryForHash(this, hash)) != first) {
- e = first = f; // re-traverse if entry changed
- retries = -1;
- }
- }
- return node;
- }
- private void scanAndLock(Object key, int hash) {
- // similar to but simpler than scanAndLockForPut
- HashEntry < K,
- V > first = entryForHash(this, hash);
- HashEntry < K,
- V > e = first;
- int retries = -1;
- while (!tryLock()) {
- HashEntry < K,
- V > f;
- if (retries < 0) {
- if (e == null || key.equals(e.key)) retries = 0;
- else e = e.next;
- } else if (++retries > MAX_SCAN_RETRIES) {
- lock();
- break;
- } else if ((retries & 1) == 0 && (f = entryForHash(this, hash)) != first) {
- e = first = f;
- retries = -1;
- }
- }
- }
- final V remove(Object key, int hash, Object value) {
- if (!tryLock()) scanAndLock(key, hash);
- V oldValue = null;
- try {
- HashEntry < K,
- V > [] tab = table;
- int index = (tab.length - 1) & hash;
- HashEntry < K,
- V > e = entryAt(tab, index);
- HashEntry < K,
- V > pred = null;
- while (e != null) {
- K k;
- HashEntry < K,
- V > next = e.next;
- if ((k = e.key) == key || (e.hash == hash && key.equals(k))) {
- V v = e.value;
- if (value == null || value == v || value.equals(v)) {
- if (pred == null) setEntryAt(tab, index, next);
- else pred.setNext(next);
- ++modCount;
- --count;
- oldValue = v;
- }
- break;
- }
- pred = e;
- e = next;
- }
- } finally {
- unlock();
- }
- return oldValue;
- }
- final boolean replace(K key, int hash, V oldValue, V newValue) {
- if (!tryLock()) scanAndLock(key, hash);
- boolean replaced = false;
- try {
- HashEntry < K,
- V > e;
- for (e = entryForHash(this, hash); e != null; e = e.next) {
- K k;
- if ((k = e.key) == key || (e.hash == hash && key.equals(k))) {
- if (oldValue.equals(e.value)) {
- e.value = newValue; ++modCount;
- replaced = true;
- }
- break;
- }
- }
- } finally {
- unlock();
- }
- return replaced;
- }
- final V replace(K key, int hash, V value) {
- if (!tryLock()) scanAndLock(key, hash);
- V oldValue = null;
- try {
- HashEntry < K,
- V > e;
- for (e = entryForHash(this, hash); e != null; e = e.next) {
- K k;
- if ((k = e.key) == key || (e.hash == hash && key.equals(k))) {
- oldValue = e.value;
- e.value = value; ++modCount;
- break;
- }
- }
- } finally {
- unlock();
- }
- return oldValue;
- }
- final void clear() {
- lock();
- try {
- HashEntry < K,
- V > [] tab = table;
- for (int i = 0; i < tab.length; i++) setEntryAt(tab, i, null);
- ++modCount;
- count = 0;
- } finally {
- unlock();
- }
- }
- }
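Listing 4 is the Segment class itself. As noted earlier, ConcurrentHashMap's put is carried out by Segment's put. Attentive readers will have noticed that Segment extends ReentrantLock, so it can call lock and unlock directly to synchronize. The code shows that its put operation is thread-safe, as are the other Segment methods. Readers who went through Listings 2, 3, and 4 carefully will also see that the length of the segments array is fixed by the concurrencyLevel passed to the constructor. Storing data never grows the segments array; what grows is the table array inside an individual Segment.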
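The idea behind Segment can be reduced to a much smaller sketch. The `StripedMap` class below is hypothetical and deliberately simplified: each stripe is a ReentrantLock guarding a plain HashMap, the stripe array never grows, and reads take the lock too, which the JDK's get does not do. It only illustrates lock striping and is not how the JDK implements it.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical, simplified lock-striped map; not the JDK implementation.
class StripedMap<K, V> {
    // Each stripe is its own lock, mirroring how Segment extends ReentrantLock.
    private static final class Stripe<K, V> extends ReentrantLock {
        private static final long serialVersionUID = 1L;
        final Map<K, V> table = new HashMap<K, V>(); // guarded by this stripe's lock
    }

    private final Stripe<K, V>[] stripes; // fixed length, never resized
    private final int mask;

    @SuppressWarnings("unchecked")
    StripedMap(int concurrencyLevel) {
        int size = 1;
        while (size < concurrencyLevel) size <<= 1; // round up to a power of two
        stripes = (Stripe<K, V>[]) new Stripe[size];
        for (int i = 0; i < size; i++) stripes[i] = new Stripe<K, V>();
        mask = size - 1;
    }

    private Stripe<K, V> stripeFor(Object key) {
        int h = key.hashCode();
        h ^= (h >>> 16); // cheap bit spreading; the JDK uses a stronger mix
        return stripes[h & mask];
    }

    V put(K key, V value) {
        Stripe<K, V> s = stripeFor(key);
        s.lock(); // only this stripe is blocked; writers on other stripes proceed
        try {
            return s.table.put(key, value);
        } finally {
            s.unlock();
        }
    }

    V get(Object key) {
        Stripe<K, V> s = stripeFor(key);
        s.lock(); // simplification: the JDK's get reads without locking
        try {
            return s.table.get(key);
        } finally {
            s.unlock();
        }
    }

    int size() {
        for (Stripe<K, V> s : stripes) s.lock(); // lock every stripe for a consistent total
        try {
            int n = 0;
            for (Stripe<K, V> s : stripes) n += s.table.size();
            return n;
        } finally {
            for (Stripe<K, V> s : stripes) s.unlock();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        final StripedMap<String, Integer> map = new StripedMap<String, Integer>(16);
        Thread[] workers = new Thread[4];
        for (int t = 0; t < workers.length; t++) {
            final int id = t;
            workers[t] = new Thread(new Runnable() {
                @Override public void run() {
                    for (int i = 0; i < 1000; i++) map.put("k" + id + "-" + i, i);
                }
            });
            workers[t].start();
        }
        for (Thread w : workers) w.join();
        System.out.println(map.size());
    }
}
```

With 16 stripes, two threads block each other only when their keys land on the same stripe, whereas Hashtable serializes every call on a single lock.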
With storing data sorted out, let's look at how to read it back; see Listing 5:
Listing 5
- public V get(Object key) {
- Segment<K,V> s; // manually integrate access methods to reduce overhead
- HashEntry<K,V>[] tab;
- int h = hash(key);
- long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE; // offset of the segment for this hash
- if ((s = (Segment<K,V>) UNSAFE.getObjectVolatile(segments, u)) != null && (tab = s.table) != null) { // volatile read (not a CAS) of the Segment at that offset, then of its table
- for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE); // head of the bucket; walk the linked list from there
- e != null; e = e.next) {
- K k;
- if ((k = e.key) == key || (e.hash == h && key.equals(k))) return e.value;
- }
- }
- return null;
- }
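Listing 5 needs little commentary: locate the segment and the bucket, walk the linked list, and return the matching value, or null if there is none. Note that get takes no lock and relies on volatile reads instead. Once the put path is clear, get is easy to follow.
Next, let's see how ConcurrentHashMap counts the key-value pairs it currently holds; see Listing 6:
Listing 6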
- public int size() {
- // Try a few times to get accurate count. On failure due to
- // continuous async changes in table, resort to locking.
- final Segment<K,V>[] segments = this.segments;
- int size;
- boolean overflow; // true if size overflows 32 bits
- long sum; // sum of modCounts
- long last = 0L; // sum from the previous pass
- int retries = -1;
- try {
- for (;;) {
- if (retries++ == RETRIES_BEFORE_LOCK) { // only once the unlocked passes reach RETRIES_BEFORE_LOCK are all segments locked, one by one
- for (int j = 0; j < segments.length; ++j) ensureSegment(j).lock(); // force creation
- }
- }
- sum = 0L;
- size = 0;
- overflow = false;
- for (int j = 0; j < segments.length; ++j) {
- Segment<K,V> seg = segmentAt(segments, j);
- if (seg != null) {
- sum += seg.modCount;
- int c = seg.count;
- if (c < 0 || (size += c) < 0) // a negative result means the total exceeded Integer.MAX_VALUE, so overflow is set
- overflow = true;
- }
- }
- if (sum == last) // equal sums mean no other thread modified the map between the two passes, so exit the loop
- break;
- last = sum;
- }
- } finally {
- if (retries > RETRIES_BEFORE_LOCK) { // unlock
- for (int j = 0; j < segments.length; ++j) segmentAt(segments, j).unlock();
- }
- }
- return overflow ? Integer.MAX_VALUE: size;
- }
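The size method first loops without locking: it walks the segments array, summing count and modCount. If the modCount total is the same on two consecutive passes (last == sum), no other thread modified the ConcurrentHashMap in between, and the accumulated size is returned. Once the number of unlocked passes exceeds RETRIES_BEFORE_LOCK, every segment is locked, the size is computed, and the segments are unlocked again. Note that every Segment must be force-created while locking; otherwise another thread could create a Segment afterwards and run put, remove, and similar operations on it.
Source: https://juejin.im/post/5a190b09f265da430b7aef0a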
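The retry-then-lock pattern can be seen in isolation in the sketch below. It keeps the control flow of Listing 6 but replaces Segment with a hypothetical `Counter` class that holds only count and modCount; `OptimisticSizeDemo` is likewise an assumed name, and both fields are declared volatile here for simplicity rather than to match the JDK.

```java
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical reduction of size(): optimistic passes first, locking as a fallback.
class OptimisticSizeDemo {
    static final int RETRIES_BEFORE_LOCK = 2;

    // Stand-in for Segment: just the two counters that size() reads.
    static final class Counter extends ReentrantLock {
        private static final long serialVersionUID = 1L;
        volatile int count;
        volatile int modCount;

        void add() {
            lock();
            try {
                count++;
                modCount++;
            } finally {
                unlock();
            }
        }
    }

    static int size(Counter[] segs) {
        int size;
        boolean overflow;
        long last = 0L;   // modCount total from the previous pass
        int retries = -1;
        try {
            for (;;) {
                if (retries++ == RETRIES_BEFORE_LOCK) {
                    for (Counter c : segs) c.lock(); // give up on optimism: lock everything
                }
                long sum = 0L;
                size = 0;
                overflow = false;
                for (Counter c : segs) {
                    sum += c.modCount;
                    int n = c.count;
                    if (n < 0 || (size += n) < 0) overflow = true;
                }
                if (sum == last) break; // two passes agree: nothing changed in between
                last = sum;
            }
        } finally {
            if (retries > RETRIES_BEFORE_LOCK) {
                for (Counter c : segs) c.unlock();
            }
        }
        return overflow ? Integer.MAX_VALUE : size;
    }

    public static void main(String[] args) {
        Counter[] segs = new Counter[4];
        for (int i = 0; i < segs.length; i++) segs[i] = new Counter();
        for (int i = 0; i < 3; i++) segs[0].add();
        for (int i = 0; i < 5; i++) segs[2].add();
        segs[3].add();
        System.out.println(size(segs)); // 9
    }
}
```

Because `last` starts at 0, an untouched map returns after a single pass; otherwise at least two passes are needed before the totals can match.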