0%

ThreadLocal学习

一、ThreadLocal解决了什么问题

在一些情况下,我们希望对于某一个共享变量,对于不同线程来说都是独一无二的,例如:有一个共享变量x,线程A去对其进行写操作,读操作,线程B也对其进行写操作读操作,而线程A和线程B在去读写x时的感觉,就像是读写一个本地变量一样,完全与外界封闭, 这就是线程封闭的想法。下面有三种线程封闭的实践:

  • Ad-hoc 线程封闭:

    • 即维护线程封闭性职责交给程序实现,很不推荐使用。
  • 栈封闭:

    • 利用线程的局部变量只有自己可见这一特性,每次都手动将共享变量拷贝一份局部变量的副本。缺点是,只有编写程序的人才知道哪些对象需要封闭,并且还要考虑栈溢出的问题。
  • ThreadLocal(线程封闭的最佳实践):

    • Java提供给我们线程封闭的实现类,下面我们来学习它

二、ThreadLocal学习

1. quick start

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
public class Basic {
// x看似是一个类共享的静态变量,实际上是线程独占的静态变量!
public static ThreadLocal<Long> x = new ThreadLocal<Long>(){
// 延迟加载的初始化,默认为null
@Override
protected Long initialValue() {
// 初始值设为当前线程的id;
return Thread.currentThread().getId();
}
};

public static void main(String[] args) {
new Thread(){
@Override
public void run() {
System.out.println(x.get());
}
}.start();
System.out.println(x.get());
/**
* 打印结果为:
* 11
* 1
*/
x.set(101l); // x被赋值为101 ps:x仅为当前main线程可见
x.remove(); // x变成了初始值 ps:x仅为当前main线程可见
}
}

2. ThreadLocal的常用场景:

  • 场景1 线程资源持有:

    在这里插入图片描述
    在Web开发中,处理每一个用户的线程都有一个对应的Session,把取出来的用户数据都存放到同一个ThreadLocal的user中,不同线程的user这样就被隔离开了.

  • 场景2 :JDBC中确保线程资源一致性

在这里插入图片描述
在一个线程中可能会有多个JDBC操作, 例如一个线程处理一个用户的请求,那么这一线程中可能就会有多次对数据库的访问操作,每一次操作对应上面的part

ThreadLocal能够让同一个线程申请到的Connection保证是同一个。之所以这样保证线程获取到Connection的一致性,是因为:JDBC层面要保证一个事务是在一个线程里维护的,这样做不仅天然的解决的同步问题,同时也方便回滚的实现,因此,每一个操作所在的那个(事务)线程要和JDBC连接池的线程进行绑定. 这里的ThreadLocalMap是ThreadLocal变量的一个容器。

同时,在Spring的分布式事务里,也是通过ThreadLocal来记录处理事务的上下文.

  • 场景3 线程安全

    在这里插入图片描述
    LastError是共享变量,初始是False,当某一线程运行错误时,会将其置为True, 如果是线程A出错,那么在它去将LastError置为True时应该只能被线程A自己看到,而不能被Thread2看到, 此时,ThreadLocal就可以保证上面的实现.

  • 场景4 分布式计算

    在这里插入图片描述

某一个巨大的任务被切分成多个线程去跑,而每个子线程运行出的子任务结果就可以保存到ThreadLocal变量中,不同的线程只可见自己的运行结果,最终由Collector完成统一组装。

举个例子:

现在有一个任务,x初始为0,通过调用 x = increment(x) 来把x加到10000, increment函数实现如下:

1
2
3
4
5
Integer increment throw Exception(Integer x){
Thread.sleep(1000);
x += 1;
return x;
}

如果我们希望使用100个线程去完成这个任务,一个可行的方法就是,每个线程加100次,然后最后把这100个线程的计算结果给整合到一起. 于是,我们可以建立一个静态HashMap被这个100个线程共享,HashMap中的每一个Entry就是一个<Thread, Integer>的映射, 然后每次在initialValue的时候,把新的Entry给put进去,当在进行结果读取的时候,通过遍历HashMap来获取每一个子任务的结果,最终进行整合.

需要注意的是,上述使用的HashMap在初始化阶段也是在并发环境中的,因此其put操作也必须是原子操作。 其次,真正实现的时候,在构成<Thread, Integer>映射的时,我们希望value是一个引用变量(Val<Integer>)。

三、ThreadLocal的简单实现

  1. MyThreadLocal.java

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    package threadlocal;

    import java.util.HashMap;
    import java.util.concurrent.atomic.AtomicInteger;

    /**
    * @author czf
    * @Date 2020/4/3 10:40 上午
    */
    public class MyThreadLocal<T> {

    static AtomicInteger atomicInteger = new AtomicInteger();
    // h(x) = 1640531527*x % ( INT_MAX ) , INT_MAX = x^{32} - 1
    // 高德纳提出的算法
    Integer threadLocalHash = atomicInteger.getAndAdd(0x61c88647);

    // 之所以是Entry<Integer, Object>而不是是Entry<MyThreadLocal<?>, Object>
    // 是因为如果使用后者的话,HashMap的引用会导致对应
    // ThreadLocal变量永远不能被GC回收
    // 其实在ThreadLocal的源码里,是一个继承了WeakReference的Entry
    // static class Entry extends WeakReference<ThreadLocal>
    // 这里只是简单实现了一下,因此就用一个唯一哈希的Int来代表唯一的ThreadLocal.
    static HashMap<Thread, HashMap<Integer, Object>> mp = new HashMap<>();

    synchronized private HashMap<Integer, Object> getMap(){
    Thread thread = Thread.currentThread();
    if ( !mp.containsKey(this) )
    mp.put(thread, new HashMap<Integer, Object>());
    return mp.get(thread);
    }

    public T get(){
    HashMap<Integer,Object> map = getMap();
    if (map.get(threadLocalHash)==null)
    map.put(threadLocalHash, initialValue());
    return (T) map.get(threadLocalHash);
    }

    public void set(T v){
    HashMap<Integer,Object> map = getMap();
    map.put(threadLocalHash, v);
    }

    private T v;
    protected T initialValue(){
    return null;
    }

    }
  2. Test.java

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30

    import java.util.HashSet;
    import java.util.Set;

    /**
    * @author czf
    * @Date 2020/4/3 10:39 上午
    */
    public class Test {
    static MyThreadLocal<Long> v = new MyThreadLocal<Long>(){
    @Override
    protected Long initialValue() {
    return Thread.currentThread().getId();
    }
    };
    public static void main(String[] args) {
    Set<Long> st = new HashSet<>();
    for(int i=0; i<100; ++i){
    new Thread(){
    @Override
    public void run() {
    st.add(v.get());
    System.out.println(v.get());
    }
    }.start();
    }
    while(st.size()!=100);
    System.out.println("size="+st.size());
    }
    }

    实现的主要思路就是,MyThreadLocal类中有一个能被所有MyThreadLocal对象访问到的一个静态HashMap,里面的Entry为 <Thread, 对应该Thread的局部变量Map>.

    对应该Thread的局部变量Map中,Entry为<能够找到这个局部变量的唯一标识符,其Value值> ; 在源码的实现里,这个唯一标识符就是ThreadLocal变量本身,考虑到其能够被GC回收,源码并没有使用容器里的Map,而是重新实现了一个ThreadLocalMap这个专门存储ThreadLocal变量,其中Entry是继承了WeakReference的一个弱引用,注意,这里只是Key是弱引用,而Value不是。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    static class Entry extends WeakReference<ThreadLocal> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal k, Object v) {
    super(k);
    value = v;
    }
    }

在上面的粗糙实现中,直接用Hash值(atomicInteger)来作为这个唯一表示。

四、使用ThreadLocal需要注意的:

  • 因为ThreadLocalMap的Entry中,Key是弱引用,但Value不是弱引用,因此,在调用完set()操作,一定还要调用对应的remove()操作。如果不这么做的话, 会导致如下问题:

    • 如果该线程是运行在线程池里的线程,那么在下一个请求得到这个线程的时候,可能会发生get到上一次请求set的值
    • 如果Thread实例还在,但是ThreadLocal实例却不在了,则ThreadLocal实例作为key所关联的value无法被外部访问,却还被强引用着,因此出现了内存泄露。
  • ThreadLocalMap实现的HashMap是使用线性探测法来解决冲突的,因此,让一个线程内部的ThreadLocal变量的个数尽可能的小,能够减少冲突次数,保证性能。

五、ThreadLocal原理:

1. 每一个Thread里都有一个Map (ThreadLocalMap)

1
2
3
4
5
6
7
8
9
10
11
12
13
public  class Thread implements Runnable {
...
/* ThreadLocal values pertaining to this thread. This map is maintained
* by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;

/*
* InheritableThreadLocal values pertaining to this thread. This map is
* maintained by the InheritableThreadLocal class.
*/
ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
...
}

2. ThreadLocal的set操作

  1. 先尝试获取map,如果获取不到就创建
  2. 得到map后就对其进行set
    1
    2
    3
    4
    5
    6
    7
    8
    9
    public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t); // 获取map
    if (map != null) {
    map.set(this, value);
    } else {
    createMap(t, value); // map如果是null就去创建
    }
    }

3. ThreadLocal的get操作

  1. 尝试获取map,如果map没有被初始化就返回初始值
  2. 如果获取到map就返回对应的值
1
2
3
4
5
6
7
8
9
10
11
12
13
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue(); // 在setInitialValue方法里面,会对没有初始化的map进行初始化
}

4. ThreadLocalMap

ThreadLocalMap是每一个WeakHashMap, 每一个线程都有着这样一个Map,并且从上面的代码中可以看出,这个Map是以懒加载的形式进行初始化的。

KV对的类型是< ThreadLocal<?> , Object >

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
/**
* Set the value associated with key.
*
* @param key the thread local object
* @param value the value to be set
*/
private void set(ThreadLocal<?> key, Object value) {

// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.

Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();

if (k == key) {
e.value = value;
return;
}

if (k == null) { // 在这里,
replaceStaleEntry(key, value, i);
return;
}
}

tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

六、哪里用到了

这个记录一下自己使用ThreadLocal的场合,在仿写Spring轮子的时候, 要在创建多例的时候需要判断是否出现了循环依赖,因此,会保存正在创建的对象,如果在创建对象的时候发现这个对象已经处于正在创建的对象的集合中,那么就说明出现了多例bean的循环依赖,需要抛出异常,在这里,这个保存正在创建的对象的集合就是一个ThreadLocal的HashSet, 因为如果不是ThreadLocal的话,在线程A正在创建X的时候,X处于正在创建的状态,此时线程B也去创建X,这时发现X已经处于正在创建的状态了,就会误以为出现了循环依赖,这是不对的,因为其他线程创建X对当前线程不应该有影响… 而这里如果使用了ThreadLocal就能够解决这个问题。(因为必须是在单个线程里出现环才说明出现了循环依赖)

参考

ThreadLocal-面试必问深度解析
ThreadLocal小结-到底会不会引起内存泄露