# ThreadLocal

# 什么是ThreadLocal

ThreadLocal为每一个线程所独有,提供了将共享内容改为线程私有化的方式,从而保证了线程安全。

# ThreadLocal灵魂拷问

  • ThreadLocal的key是弱引用,那么在 threadLocal.get()的时候,发生GC之后,key是否为null?
  • ThreadLocal中ThreadLocalMap的数据结构?
  • ThreadLocalMap的Hash算法?
  • ThreadLocalMap中Hash冲突如何解决?
  • ThreadLocalMap扩容机制?
  • ThreadLocalMap中过期key的清理机制?探测式清理和启发式清理流程?
  • ThreadLocalMap.set()方法实现原理?
  • ThreadLocalMap.get()方法实现原理?
  • 项目中ThreadLocal使用情况?遇到的坑?

# ThreadLocal的实现原理

每个Thread中维护了一个ThreadLocalMap属性,ThreadLocalMap中使用数组存储了当前线程所有的ThreadLocal对象。

ThreadLocalMap的结构类似与HashMap,但不存在链表结构只有数组,当hash冲突时,就继续计算下一个hash位置,直到数组当前位置为null。

# 为什么会导致内存泄漏

原因有两点

  1. 如果Thread一直存在,并且对ThreadLocal存在强引用关系,因此其存放的值不会释放,会一直占用着内存。此类情况有可能导致内存释放(不容易发生,除非代码逻辑过多,存放的值过大)。
  2. 如果Thread一直存在,对ThreadLocal无强引用关系(本身是WeakReference),则GC时key会被回收,但是value仍然存在(JDK的设计上会在继续使用其相关方法时清理key为null的entry,但不使用时仍然造成了一段时间的内存泄漏。),依然有内存泄漏的风险。此类情况在使用线程池时较为常见。

如何解决

在使用完之后调用remove()

# InheritableThreadLocal

使用ThreadLocal时,异步场景下无法给子线程共享父线程中的线程副本数据,JDK提供了InheritableThreadLocal来解决这个问题。

其实现原理是父线程通过new Thread()创建子线程时,Thread#init方法在Thread的构造方法中被调用。在init方法中会将父线程的InheritableThreadLocal数据拷贝到子线程中,但这种拷贝只有在创建子线程时才会进行一次,线程池的方式是不适用的。可以用阿里开源的TransmittableThreadLocal组件来解决线程池的场景。

# 源码解读

从简单的使用ThreadLocal进入源码

public static void main(String[] args) {
    ThreadLocal<Integer> tl = new ThreadLocal<>();
    // ThreadLocal使用就只需要以下三个方法
    // 设值
    tl.set(5);
    // 获取值
    tl.get();
    // 移除
    tl.remove();
}

ThreadLocal类的重要属性

// ThreadLocal对象的HashCode
private final int threadLocalHashCode = nextHashCode();

// 使用原子类记录下一个HashCode,注意这是一个静态类,所有ThreadLocal都共享
private static AtomicInteger nextHashCode = new AtomicInteger();

// 它是斐波那契数,也叫做黄金分割数,可以让hash分布非常均匀
private static final int HASH_INCREMENT = 0x61c88647;

// 每new一个ThreadLocal实例,hashCode递增固定间隔
private static int nextHashCode() {
    return nextHashCode.getAndAdd(HASH_INCREMENT);
}

ThreadLocal的set、get、remove方法都比较简单,主要逻辑都在ThreadLocalMap中

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

接下来看ThreadLocalMap的set代码

private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    // 将hashCode通过与运算确定在数组中的位置
    int i = key.threadLocalHashCode & (len-1);
    
    // 这个循环,在hash冲突时会通过nextIndex方法线性向后查找下一个槽位
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        // 若两个ThreadLocal对象相等,则更新value
        if (k == key) {
            e.value = value;
            return;
        }
        // k若没有被强引用,有可能会被回收,但value不会,因此需要执行清理操作(替换过期的数据)
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    
    // 通过上面循环后,当前i位置肯定为null的槽位,创建一个entry占领槽位
    tab[i] = new Entry(key, value);
    int sz = ++size;
    // 若没有清理到槽位,且当前table中存在的Entry大于threshold(table.length的2/3)
    // 则进行rehash操作
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

// 替换过期的数据
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;
	// 从staleSlot开始,向数组左边遍历,不停更新过期槽位的下标,直到遇到null
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // 进入当前方法是staleSlot的Entry的key==null,现在往staleSlot后面寻找,因为一开始staleSlot位置不是过期的数据,所以对应的ThreadLocal可能会被放置到这之后了,因此需要找到并将其更新到前面的staleSlot位置。
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        
        ThreadLocal<?> k = e.get();
		// 在staleSlot后面找到相等的ThreadLocal
        if (k == key) {
            // 更新value
            e.value = value;
			// 将当前存在的Entry更新到staleSlot槽位上
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // 若staleSlot前面还有过期的Entry,则从它开始清理
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // 环状循环一遍后,遍历到初始位置后,刷新清理的位置为初始位置
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // 经过上面逻辑后,staleSlot一定是可以替换的位置
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // staleSlot前面可能还存在过期的Entry,进行清理
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

// 向后遍历少数槽位(只遍历n的二进制位数的次数),检查是否存在stale entry,进行清理
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            // 更新为len,重置清理的次数
            n = len;
            // 只要清理成功一个,removed就标记为true
            removed = true;
            // i被更新为下一个不为null的下标
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

// 删除过期的Entry
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    // 删除staleSlot槽位上的Entry
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    Entry e;
    int i;
    // 向后遍历,直到null
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        // 遇到过期槽位,将其置空
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            // 计算槽位上的hashCode,看是否与下标对应,不对应说明set时是往后面线性查找null存储过
            // 由于前面有null的槽位了,可以将Entry重新从一次hashCode位置开始找到第一个为null的槽位存放
            // 也就是重新整理staleSlot后面的槽位
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    // 返回下一个不为null的槽位下标
    return i;
}

private void rehash() {
    expungeStaleEntries();
    // 使用原threshold*3/4的阈值来判断,也就是降低threshold,因为前一步可能会清理掉一些stale entry
    // 能让其在清理不足量的场景下,继续rehash。
    if (size >= threshold - threshold / 4)
        resize();
}

private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    // 遍历table,遇到stale entry就清理掉
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            expungeStaleEntry(j);
    }
}

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    // 扩容为原来的两倍
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;
    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }
    setThreshold(newLen);
    size = count;
    table = newTab;
}

# refer

万字图文深度解析ThreadLocal (opens new window)

修改于: 8/11/2022, 3:17:56 PM