ThreadLocal源码阅读

一、ThreadLocal概述

        在Java开发中,多线程是一个永远绕不开的话题。Java服务器中通常使用一组线程处理一个会话或一个连接,这一组线程一般具有父子关系,它们往往需要客户端传递的用户信息来完成业务逻辑,这些信息仅针对本次连接或会话有效。当然我们可以以参数的形式传递这些数据,但当参数异常复杂的时候,层间传递就变得很枯燥,有没有一种方法可以简化这种场景?当然有,他就是ThreadLocal技术,ThreadLocal是一种线程局部变量共享机制,它独立于语言,但本次仅讨论Java语言范畴。

        使用ThreadLocal维护的变量在每个线程内部维护一个副本,对该变量的修改仅对该线程可见(指ThreadLocal.set(obj),对象成员修改除外),同样可以在线程内部的任意位置使用它。这样,既满足了线程之间的隔离要求,又减少了参数的传递工作,何乐而不为之!另外,我们上面提到,如果是一组线程(这一组线程具有父子关系)需要变量共享呢?没问题,Java给我们提供了InheritableThreadLocal,只要我们在父线程中设置了相关变量,子线程会自动继承这些变量的值,但本质上是子线程初始化自己的副本时使用了父线程的值,此后对各自副本的修改(指ThreadLocal.set(obj),对象成员修改除外)也仅在当前线程生效

        注意,网上很多博客写着ThreadLocal是用来解决多线程并发问题的,这种理解在我看来是错误的。多线程并发指的是多个线程对同一个临界区操作的互斥问题,或多个线程之间的同步问题,而ThreadLocal在每个线程都有一个副本,不存在互斥,也不存在同步,因此跟多线程并发问题无关。

二、ThreadLocal运用

上面简单介绍了ThreadLocal的性质,下面来看一下具体的运用场景:

  1. Session管理:正如我们开篇提到的用户信息传递问题,本质上就是一个Session管理,它在一次会话开始的时候创建ThreadLocal变量保存所有全局信息,会话结束的时候释放ThreadLocal。如果线程的生命周期与会话的生命周期一致,则可以不用手动释放ThreadLocal变量,如果使用了线程池就必须在提交任务时手动初始化ThreadLocal,结束任务时手动清理ThreadLocal保存的数据,否则就可能使用的前一个会话遗留的脏数据。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    @RestController
    public class UserController {
    private static final ThreadLocal<UserInfo> USER_INFO = new ThreadLocal<>();
    private static final Executor EXECUTOR = Executors.newFixedThreadPool(10);

    @Autowired
    private UserService userService;

    @RequestMapping("/user")
    public User login(UserInfo userInfo) {
    USER_INFO.set(userInfo);
    User user = userService.login();
    EXECUTOR.execute(()->{
    //注意线程池需要手动管理ThreadLocal
    USER_INFO.set(userInfo);
    userService.doSomething();
    USER_INFO.remove();
    });
    return user;
    }
    }
  2. 连接管理:当我们有一个线程池用来处理一些远程任务,每个任务都需要与远程主机建立连接,那么为了减少频繁建立连接带来的性能开销,我们可以使用ThreadLocal来保存这些连接,使之与线程的生命周期一致,这样就避免了频繁建立远程连接带来的开销。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    @RestController
    public class UserController {
    private static final ThreadLocal<UserInfo> USER_INFO = new ThreadLocal<>();
    private static final ThreadLocal<RemoteConnection> REMOTE_CONNECTION = new ThreadLocal<>();
    private static final Executor EXECUTOR = Executors.newFixedThreadPool(10);

    @Autowired
    private UserService userService;

    @RequestMapping("/user")
    public User login(UserInfo userInfo) {
    USER_INFO.set(userInfo);
    User user = userService.login();
    EXECUTOR.execute(()->{
    //注意线程池需要手动管理ThreadLocal
    USER_INFO.set(userInfo);
    //仅需要第一个任务初始化connection
    RemoteConnection connection = REMOTE_CONNECTION.get();
    if(connection == null) {
    connection = RemoteUtil.getConnection();
    REMOTE_CONNECTION.set(connection);
    }
    userService.doSomething(connection);
    USER_INFO.remove();
    });
    return user;
    }
    }
  3. ThreadLocal在Spring事务管理中的应用

三、ThreadLocal原理

3.1 ThreadLocal原理

        事实上ThreadLocal本身并不存储数据,它只是数据的 管家 。在Thread内部有threadLocals(对应ThreadLocal)和inheritableThreadLocals(对应InheritableThreadLocal)两个Map(它们也是hash map,但有自己的实现),它们维护者ThreadLocal/InheritableThreadLocal的副本。下面以threadLocals为例,threadLocals类型是ThreadLocal.ThreadLocalMap,它定义在ThreadLocal类中,其key是ThreadLocal变量的弱引用,value是对应副本值。

        对于同一个ThreadLocal在不同Thread中,threadLocals中的key是同一个对象,这确保了同一个threadLocal变量能检索到 同一个类型 的value,但value在不同线程之间是独立的。

        对于不同ThreadLocal在同一个线程中来说,不同的Thread Local对应 不同类型 的值,也就是threadLocals中的多个entry。

ThreadLocal与多线程之间的关系

        ThreadLocalMap中的key继承了WeakReference,因为ThreadLocal对于用户而言就是一个普通变量,它的生命周期应当符合一般变量行为。如果这里是强引用,那么即便用户将其引用置为null(或者方法返回、对象回收等等),该ThreadLocal对象可能依然无法被回收,因为还有其他线程的threadLocalMap中的entry对其有强引用。

        将Entry中的key设为弱引用即可解决ThreadLocal GC回收的问题,但对应value又会带来内存泄露,对于value而言依然有thread -> threadLocalMap -> entry -> value这样的引用链存在,且该value永远无法被访问,直到线程结束。为解决这一问题,该Map中新增了部分对”stale entry”的回收逻辑。

ThreadLocal弱引用

3.2 ThreadLocal源码

点击展开代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
//一个变量对应一个ThreadLocal,该变量在不同线程之间有不同的副本
public class ThreadLocal<T> {
//每个线程都有一个threadLocalMap变量,其key为一个ThreadLocal变量,value为我们设置的值
//这里的threadLocalHashCode就充当了key的hashCode,它就是threadLocalMap中用于计算hash地址的依据
//ThreadLocal中计算hashCode的方法很简单,就是从0开始不断累加HASH_INCREMENT
//为保证不同ThreadLocal对象之间HashCode不同,其不断累加的临时变量nextHashCode是static的
private final int threadLocalHashCode = nextHashCode();

//用于计算threadLocalHashCode的【静态】临时变量,它保存的是下一个ThreadLocal变量的Hash地址
//并每创建一个ThreadLocal对象更新一次
private static AtomicInteger nextHashCode = new AtomicInteger();

//不断累加的HashCode差值,这个值可以保证在2^N范围类hashcode均匀分布,保证检索效率。
//该值与黄金分割和斐波那契散列算法有关,参考:https://zhuanlan.zhihu.com/p/40515974
private static final int HASH_INCREMENT = 0x61c88647;

//计算nextHashCode并返回当前对象的hashCode
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}

//当某个线程下对应的ThreadLocal变量未设置时(threadLocalMap对应的k-v不存在)
//就会调用该方法返回一个初始值,默认返回null,可以继承重写该方法
protected T initialValue() {
return null;
}

//返回一个具有指定初始化supplier的ThreadLocal对象
//它会使用supplier获取(在initialValue方法中调用)初始值
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);
}

//默认无参构造方法
public ThreadLocal() {
}

//返回该ThreadLocal变量对应于当前线程的副本值
public T get() {
//获取当前线程
Thread t = Thread.currentThread();
//获取当前线程的ThreadLocalMap
ThreadLocalMap map = getMap(t);
if (map != null) {
//获取当前threadLocal对应的entry
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
//当前threadLocal对应的value
T result = (T)e.value;
return result;
}
}
//当前threadLocal对应值不存在就初始化,并返回初始化的值
return setInitialValue();
}

//如果当前线程的ThreadLocalMap为null或者当前ThreadLocal值不存在,
//就初始化map(如果需要的话)和当前threadLocal对应的k-v,并返回初始值
//注意ThreadLocalMap虽然在Thread中,但Thread并不初始化,全部交由ThreadLocal管理
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
//map存在就设置初始值
if (map != null)
map.set(this, value);
//map不存在就创建map并初始化
else
createMap(t, value);
return value;
}

//设置当前线程对应ThreadLocal的值
//过程与setInitialValue一致
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}

//移除当前线程对应于该ThreadLocal的副本(建议不使用threadLocal的时候手动释放)
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}

//获取线程关联的threadLocalMap
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}

//创建一个ThreadLocalMap并使用给定初始化参数(this->firstValue)初始化
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}

//创建一个从父线程继承下来的ThreadLocalMap,该Map包含父线程ThreadLocalMap的全部值
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
return new ThreadLocalMap(parentMap);
}

//InheritableThreadLocal中使用
T childValue(T parentValue) {
throw new UnsupportedOperationException();
}

//用supplier作为初始化器的ThreadLocal实现类
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

//用户提供的初始化器
private final Supplier<? extends T> supplier;

SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}

//通过supplier获取初始值
@Override
protected T initialValue() {
return supplier.get();
}
}

//ThreadLocalMap的实现,仅用于保存ThreadLocal变量
static class ThreadLocalMap {

//Entry定义,注意key的弱引用
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

//初始容量
private static final int INITIAL_CAPACITY = 16;

//hash桶
private Entry[] table;

//存储的entity个数
private int size = 0;

//扩容阈值
private int threshold; // Default to 0

//设置扩容阈值为2/3
private void setThreshold(int len) {
threshold = len * 2 / 3;
}

//使用线性探查法寻找下一个地址
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}

//使用线性探查法寻找上一个地址(回收脏entry的时候使用)
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}

//根据给定的初始值,初始化ThreadLocalMap(至少有一个值才会初始化)
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}

//拷贝构造
private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];

for (int j = 0; j < len; j++) {
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
//使用线性探查法解决冲突,和hashMap不一样
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}

//根据key获取entry
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
//注意这里key的比较是==,不是equals,因为不同线程存的threadLocal就是同一个对象的引用
if (e != null && e.get() == key)
return e;
else //线性探查寻找entry
return getEntryAfterMiss(key, i, e);
}

//线性探查寻找entry
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;

while (e != null) {
ThreadLocal<?> k = e.get();
//找到了
if (k == key)
return e;
//key==null表示该value是一个泄露值(永远无法被使用),需要进行回收
if (k == null)
expungeStaleEntry(i);
else//继续探查下一个位置
i = nextIndex(i, len);
e = tab[i];
}//没有找到返回null
return null;
}

//添加一个键值对
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();

//key存在直接替换值
if (k == key) {
e.value = value;
return;
}

//如果当前选定位置是一个“stale entry”,则按照一定的算法插入entry
//并清理一定范围内的“stale entry”
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
//如果冲突了就再次返回调用nextIndex探查下一个地址
}

tab[i] = new Entry(key, value);
int sz = ++size;
//如果???
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

//移除一个元素(不会导致"stale entry")
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear(); //清除value
//清理脏entry
expungeStaleEntry(i);
return;
}
}
}

// 插入新entry并清理周围"stale entry"
// 新的entry一定在staleSlot位置
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;

//回溯查找泄露entry的最小位置,直到遇见一个未占用的hash桶为止
//中途记录下泄露entry的最小位置,作为清理的起始点
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null;
i = prevIndex(i, len)) {
if (e.get() == null)
slotToExpunge = i;
}

//向后查找泄露entry的最大位置,直到遇见一个未占用的hash桶或找到一个可替换的entry为止
//并在合适的时机插入新entry
//注意:从i到下一个未占用的hash桶之间是必须要遍历的,这样可以检查我们要插入的key先前
//是否出现过,如果出现过就必须替换value,这样才能保证key的唯一性
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();

//如果发现有可替换的entry
if (k == key) {
e.value = value; //替换value
//把当前entry换到staleSlot处,注意staleSlot是该key第一个可用位置
//交换既满足了key的唯一性,又尽可能保证key的检索效率
tab[i] = tab[staleSlot];
tab[staleSlot] = e;

//如果一直没有找到脏entry,就以当前位置为起点清理
//注意,当前位置是和staleSlot交换过的,所以当前位置一定是一个脏entry
if (slotToExpunge == staleSlot)
slotToExpunge = i;
//清理脏entry
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}

//如果前向搜索没有发现脏enyrt,并且当前节点是脏entry,就以当前位置为清理的起始点
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}// end for

//如果没有找到可替换的entry,将staleSlot位置释放,并重新填充一个新entry
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);

//如果遍历过程中发现了其他脏entry,则清理之
//slotToExpunge一直是清理的起始点,起始点尽可能小,清理范围尽可能大
if (slotToExpunge != staleSlot)
//两次清理
//第一次:清理slotToExpunge到下一个null slot之间的脏entry
//第二次:从下一个null的下一个位置开始,最少扫描log2(len)次
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

//清理stale entry
//参数:staleSlot 当前stale entry的位置
//返回:下一个空hash桶(slot)的位置,这个区间内所有的脏entry都会被清理掉
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

// 清理当前指定位置staleSlot上的stale entry
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

//向后遍历直到下一个空桶位置
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
//遇见stale entry则清理之
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
//如果遇见正常的entry就重新计算地址
//因为前面释放了stale entry,当前entry很大可能有更优的存储位置
//减少e的线性探查次数,提升访问效率
int h = k.threadLocalHashCode & (len - 1);
//e的hash地址不在当前位置,证明e一定是因为hash地址冲突而放到了这里
//现在释放了之前部分stale entry,e很可能有更优秀的位置
//如果h==i,则当前位置就是e的最佳位置,不做任何操作
if (h != i) {
//释放当前位置
tab[i] = null;
//从h向后重新给e找位置
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
//返回下一个空slot的位置
return i;
}

//从i的下一个位置开始清理
//n控制扫描次数,扫描log2(n)轮,每一轮根据情况扫描一段或一个,整体时间复杂度为nlog2(n)
//兼顾扫描效率和清理效果
//如果有stale entry被删除则返回true
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
//如果发现stale entry
if (e != null && e.get() == null) {
n = len;
removed = true;
//移除[i, next null slot]之间的脏entry,i为next null slot位置
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}

//rehash
private void rehash() {
expungeStaleEntries();
// size >= 5 / 12(len) ?
if (size >= threshold - threshold / 4)
resize();
}

//扩容,新容量为旧容量的两倍
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;

//遍历旧table(将旧table中的内容移动到新table)
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
//如果还遇到stale entry,直接unlink
if (k == null) {
e.value = null; // Help the GC
}
//计算e在新表中的位置,并放入其中
else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}

//更新threshold
setThreshold(newLen);
size = count;
table = newTab;
}

// 清理所有脏entry,并重新计算每一个有效元素的最佳位置
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}
}
文章作者: Jack.Charles
文章链接: https://blog.zjee.me/2020/04/19/thread-local/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 江影不沉浮