【学习笔记】深入理解ThreadLocal

347次阅读  |  发布于3年以前

目录

一 引言

ThreadLocal的官方API解释为:

 * This class provides thread-local variables.  These variables differ from
 * their normal counterparts in that each thread that accesses one (via its
 * {@code get} or {@code set} method) has its own, independently initialized
 * copy of the variable.  {@code ThreadLocal} instances are typically private
 * static fields in classes that wish to associate state with a thread (e.g.,
 * a user ID or Transaction ID).

这个类提供线程局部变量。这些变量与正常的变量不同,每个线程访问一个(通过它的get或set方法)都有它自己的、独立初始化的变量副本。ThreadLocal实例通常是类中的私有静态字段,希望将状态与线程关联(例如,用户ID或事务ID)。

1、当使用ThreadLocal维护变量时,ThreadLocal为每个使用该变量的线程提供独立的变量副本,
        所以每一个线程都可以独立地改变自己的副本,而不会影响其它线程所对应的副本
2、使用ThreadLocal通常是定义为 private static ,更好是 private final static
3、Synchronized用于线程间的数据共享,而ThreadLocal则用于线程间的数据隔离
4、ThreadLocal类封装了getMap()、Set()、Get()、Remove()4个核心方法

从表面上来看ThreadLocal内部是封闭了一个Map数组,来实现对象的线程封闭,map的key就是当前的线程id,value就是我们要存储的对象。

实际上是ThreadLocal的静态内部类ThreadLocalMap为每个Thread都维护了一个数组table,hreadLocal确定了一个数组下标,而这个下标就是value存储的对应位置,继承自弱引用,用来保存ThreadLocal和Value之间的对应关系,之所以用弱引用,是为了解决线程与ThreadLocal之间的强绑定关系,会导致如果线程没有被回收,则GC便一直无法回收这部分内容。

二 源码剖析

2.1 ThreadLocal
    //set方法
    public void set(T value) {
        //获取当前线程
        Thread t = Thread.currentThread();
        //实际存储的数据结构类型
        ThreadLocalMap map = getMap(t);
        //判断map是否为空,如果有就set当前对象,没有创建一个ThreadLocalMap
        //并且将其中的值放入创建对象中
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

    //get方法 
    public T get() {
        //获取当前线程
        Thread t = Thread.currentThread();
        //实际存储的数据结构类型
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            //传入了当前线程的ID,到底层Map Entry里面去取
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }   

    //remove方法
     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);//调用ThreadLocalMap删除变量
     }

       //ThreadLocalMap中getEntry方法
      private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        } 

   //getMap()方法
   ThreadLocalMap getMap(Thread t) {
    //Thread中维护了一个ThreadLocalMap
        return t.threadLocals;
    }

    //setInitialValue方法
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

    //createMap()方法
   void createMap(Thread t, T firstValue) {
   //实例化一个新的ThreadLocalMap,并赋值给线程的成员变量threadLocals
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

从上面源码中我们看到不管是 set() get() remove() 他们都是操作ThreadLocalMap这个静态内部类的,每一个新的线程Thread都会实例化一个ThreadLocalMap并赋值给成员变量threadLocals,使用时若已经存在threadLocals则直接使用已经存在的对象

ThreadLocal.get()

ThreadLocal.set()

ThreadLocal.remove()

2.2 ThreadLocalMap

ThreadLocalMap是ThreadLocal的一个内部类

static class ThreadLocalMap {

         /**    
         * 自定义一个Entry类,并继承自弱引用
         * 同时让ThreadLocal和储值形成key-value的关系
         * 之所以用弱引用,是为了解决线程与ThreadLocal之间的强绑定关系
         * 会导致如果线程没有被回收,则GC便一直无法回收这部分内容
         * 
         */
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * Entry数组的初始化大小(初始化长度16,后续每次都是2倍扩容)
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * 根据需要调整大小
         * 长度必须是2的N次幂
         */
        private Entry[] table;

        /**
         * The number of entries in the table.
         * table中的个数
         */
        private int size = 0;

        /**
         * The next size value at which to resize.
         * 下一个要调整大小的大小值(扩容的阈值)
         */
        private int threshold; // Default to 0

        /**
         * Set the resize threshold to maintain at worst a 2/3 load factor.
         * 根据长度计算扩容阈值
         * 保持一定的负债系数
         */
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        /**
         * Increment i modulo len
         * nextIndex:从字面意思我们可以看出来就是获取下一个索引
         * 获取下一个索引,超出长度则返回
         */
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }

        /**
         * Decrement i modulo len.
         * 返回上一个索引,如果-1为负数,返回长度-1的索引
         */
        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }

        /**
         * ThreadLocalMap构造方法
         * ThreadLocalMaps是延迟构造的,因此只有在至少要放置一个节点时才创建一个
         */
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            //内部成员数组,INITIAL_CAPACITY值为16的常量
            table = new Entry[INITIAL_CAPACITY];
            //通过threadLocalHashCode(HashCode) & (长度-1)的位运算,确定键值对的位置
            //位运算,结果与取模相同,计算出需要存放的位置
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            // 创建一个新节点保存在table当中
            table[i] = new Entry(firstKey, firstValue);
            //设置table元素为1
            size = 1;
            //根据长度计算扩容阈值
            setThreshold(INITIAL_CAPACITY);
        }

        /**
         * 构造一个包含所有可继承ThreadLocals的新映射,只能createInheritedMap调用
         * ThreadLocal本身是线程隔离的,一般来说是不会出现数据共享和传递的行为
         */
        private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];

            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }

        /**
         * ThreadLocalMap中getEntry方法
         */
        private Entry getEntry(ThreadLocal<?> key) {
            //通过hashcode确定下标
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            //如果找到则直接返回
            if (e != null && e.get() == key)
                return e;
            else
             // 找不到的话接着从i位置开始向后遍历,基于线性探测法,是有可能在i之后的位置找到的
                return getEntryAfterMiss(key, i, e);
        }


        /**
         * ThreadLocalMap的set方法
         */
        private void set(ThreadLocal<?> key, Object value) {
           //新开一个引用指向table
            Entry[] tab = table;
            //获取table长度
            int len = tab.length;
            ////获取索引值,threadLocalHashCode进行一个位运算(取模)得到索引i
            int i = key.threadLocalHashCode & (len-1);
            /**
            * 遍历tab如果已经存在(key)则更新值(value)
            * 如果该key已经被回收失效,则替换该失效的key
            **/
            //
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }
                //如果 k 为null,则替换当前失效的k所在Entry节点
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            //如果上面没有遍历成功则创建新值
            tab[i] = new Entry(key, value);
            // table内元素size自增
            int sz = ++size;
            //满足条件数组扩容x2
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

        /**
         * remove方法
         * 将ThreadLocal对象对应的Entry节点从table当中删除
         */
        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();//将引用设置null,方便GC回收
                    expungeStaleEntry(i);//从i的位置开始连续段清理工作
                    return;
                }
            }
        }

        /**
        * ThreadLocalMap中replaceStaleEntry方法
         */
        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            // 新建一个引用指向table
            Entry[] tab = table;
            //获取table的长度
            int len = tab.length;
            Entry e;


            // 记录当前失效的节点下标
            int slotToExpunge = staleSlot;

           /**
             * 通过prevIndex(staleSlot, len)可以看出,由staleSlot下标向前扫描
             * 查找并记录最前位置value为null的下标
             */
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            // nextIndex(staleSlot, len)可以看出,这个是向后扫描
            // occurs first
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                 // 获取Entry节点对应的ThreadLocal对象
                ThreadLocal<?> k = e.get();

                  //如果和新的key相等的话,就直接赋值给value,替换i和staleSlot的下标
                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // 如果之前的元素存在,则开始调用cleanSomeSlots清理
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                     /**
                     *在调用cleanSomeSlots()    清理之前,会调用
                     *expungeStaleEntry()从slotToExpunge到table下标所在为
                     *null的连续段进行一次清理,返回值就是table为null的下标
                     *然后以该下标 len进行一次启发式清理
                     * 最终里面的方法实际上还是调用了expungeStaleEntry
                      * 可以看出expungeStaleEntry方法是ThreadLocal核心的清理函数
                     */
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // If we didn't find stale entry on backward scan, the
                // first stale entry seen while scanning for key is the
                // first still present in the run.
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // 如果在table中没有找到这个key,则直接在当前位置new Entry(key, value)
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // 如果有其他过时的节点正在运行,会将它们进行清除,slotToExpunge会被重新赋值
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

        /**
         * expungeStaleEntry() 启发式地清理被回收的Entry
         * 有两个地方调用到这个方法
         * 1、set方法,在判断是否需要resize之前,会清理并rehash一遍
         * 2、替换失效的节点时候,也会进行一次清理
        */
          private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                //判断如果Entry对象不为空
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    //调用该方法进行回收,
                    //对 i 开始到table所在下标为null的范围内进行一次清理和rehash
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }  

        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }



        /**
         * Re-pack and/or re-size the table. First scan the entire
         * table removing stale entries. If this doesn't sufficiently
         * shrink the size of the table, double the table size.
         */
        private void rehash() {
            expungeStaleEntries();

            // Use lower threshold for doubling to avoid hysteresis
            if (size >= threshold - threshold / 4)
                resize();
        }

        /**
         * 对table进行扩容,因为要保证table的长度是2的幂,所以扩容就扩大2倍
         */
        private void resize() {
        //获取旧table的长度
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            //创建一个长度为旧长度2倍的Entry数组
            Entry[] newTab = new Entry[newLen];
            //记录插入的有效Entry节点数
            int count = 0;

             /**
             * 从下标0开始,逐个向后遍历插入到新的table当中
             * 通过hashcode & len - 1计算下标,如果该位置已经有Entry数组,则通过线性探测向后探测插入
             */
            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {//如遇到key已经为null,则value设置null,方便GC回收
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }
            // 重新设置扩容的阈值
            setThreshold(newLen);
            // 更新size
            size = count;
             // 指向新的Entry数组
            table = newTab;
        }


    }

ThreadLocalMap.set()

ThreadLocalMap.expungeStaleEntry()

ThreadLocalMap.remove()

三 案例

目录结构:

在这里插入图片描述HttpFilter.java

package com.lyy.threadlocal.config;

import lombok.extern.slf4j.Slf4j;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;

@Slf4j
public class HttpFilter implements Filter {

//初始化需要做的事情
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {

    }

    //核心操作在这个里面
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest)servletRequest;
//        request.getSession().getAttribute("user");
        System.out.println("do filter:"+Thread.currentThread().getId()+":"+request.getServletPath());
        RequestHolder.add(Thread.currentThread().getId());
        //让这个请求完,,同时做下一步处理
        filterChain.doFilter(servletRequest,servletResponse);


    }

    //不再使用的时候做的事情
    @Override
    public void destroy() {

    }
}

HttpInterceptor.java

package com.lyy.threadlocal.config;

import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

public class HttpInterceptor extends HandlerInterceptorAdapter {

    //接口处理之前
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        System.out.println("preHandle:");
        return true;
    }

    //接口处理之后
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        RequestHolder.remove();
       System.out.println("afterCompletion");

        return;
    }
}

RequestHolder.java

package com.lyy.threadlocal.config;

public class RequestHolder {

    private final static ThreadLocal<Long> requestHolder = new ThreadLocal<>();//

    //提供方法传递数据
    public static void add(Long id){
        requestHolder.set(id);

    }

    public static Long getId(){
        //传入了当前线程的ID,到底层Map里面去取
        return requestHolder.get();
    }

    //移除变量信息,否则会造成逸出,导致内容永远不会释放掉
    public static void remove(){
        requestHolder.remove();
    }
}

ThreadLocalController.java

package com.lyy.threadlocal.controller;

import com.lyy.threadlocal.config.RequestHolder;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.ResponseBody;

@Controller
@RequestMapping("/thredLocal")
public class ThreadLocalController {

    @RequestMapping("test")
    @ResponseBody
    public Long test(){
        return RequestHolder.getId();
    }

}

ThreadlocalDemoApplication.java

package com.lyy.threadlocal;

import com.lyy.threadlocal.config.HttpFilter;
import com.lyy.threadlocal.config.HttpInterceptor;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter;

@SpringBootApplication
public class ThreadlocalDemoApplication extends WebMvcConfigurerAdapter {

    public static void main(String[] args) {
        SpringApplication.run(ThreadlocalDemoApplication.class, args);
    }

    @Bean
    public FilterRegistrationBean httpFilter(){
        FilterRegistrationBean registrationBean = new FilterRegistrationBean();
        registrationBean.setFilter(new HttpFilter());
        registrationBean.addUrlPatterns("/thredLocal/*");


        return registrationBean;
    }


    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new HttpInterceptor()).addPathPatterns("/**");
    }

}

输入:http://localhost:8080/thredLocal/test

后台打印:

do filter:35:/thredLocal/test preHandle:
afterCompletion

四 总结

1、ThreadLocal是通过每个线程单独一份存储空间,每个ThreadLocal只能保存一个变量副本。 2、相比于Synchronized,ThreadLocal具有线程隔离的效果,只有在线程内才能获取到对应的值,线程外则不能访问到想要的值,很好的实现了线程封闭。 3、每次使用完ThreadLocal,都调用它的remove()方法,清除数据,避免内存泄漏的风险 4、通过上面的源码分析,我们也可以看到大神在写代码的时候会考虑到整体实现的方方面面,一些逻辑上的处理是真严谨的,我们在看源代码的时候不能只是做了解,也要看到别人实现功能后面的目的。

源码地址:https://github.com/839022478/other/tree/master/threadlocal_demo

Copyright© 2013-2020

All Rights Reserved 京ICP备2023019179号-8