手写SpringBoot项目XSS攻击过滤器实现

一、先来个简介

什么是XSS?

百度百科的解释: XSS又叫CSS  (Cross Site Script) ,跨站脚本攻击。它指的是恶意攻击者往Web页面里插入恶意html代码,当用户浏览该页之时,嵌入其中Web里面的html代码会被执行,从而达到恶意用户的特殊目的。

它与SQL注入攻击类似,SQL注入攻击中以SQL语句作为用户输入,从而达到查询/修改/删除数据的目的,而在xss攻击中,通过插入恶意脚本,实现对用户游览器的控制,获取用户的一些信息。

二、XSS分类

xss攻击可以分成两种类型:

1.非持久型攻击
2.持久型攻击

非持久型xss攻击:顾名思义,非持久型xss攻击是一次性的,仅对当次的页面访问产生影响。非持久型xss攻击要求用户访问一个被攻击者篡改后的链接,用户访问该链接时,被植入的攻击脚本被用户游览器执行,从而达到攻击目的。

持久型xss攻击:持久型xss,会把攻击者的数据存储在服务器端,攻击行为将伴随着攻击数据一直存在。 

也可以分成三类:

反射型:经过后端,不经过数据库

存储型:经过后端,经过数据库

三、废话不多说直接上代码

先加pom文件加上依赖

<dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-text</artifactId>
            <version>1.4</version>
   </dependency>

1.首先是要写个过滤器的包装类,这也是实现XSS攻击过滤的核心代码。

package com.hrt.zxxc.fxspg.xss;

import com.alibaba.fastjson.JSON;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringEscapeUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.Map;

/**
 * @program: fxspg
 * @description: XSS过滤具体核心代码
 * @author: liumingyu
 * @date: 2020-01-10 14:28
 **/
public class XssAndSqlHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /**
     * @return
     * @Author liumingyu
     * @Description //TODO 构造函数,传入参数,执行超类
     * @Date 2020/1/10 2:29 下午
     * @Param [request]
     **/
    public XssAndSqlHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    /**
     * @return java.lang.String
     * @Author liumingyu
     * @Description //TODO 重写getParameter方法 ,getParameter方法是直接通过request获得querystring类型的入参调用的方法
     * @Date 2020/1/10 2:31 下午
     * @Param [name]
     **/
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(name);
        if (!StringUtils.isEmpty(value)) {
            //调用Apache的工具类:StringEscapeUtils.escapeHtml4
            value = StringEscapeUtils.escapeHtml4(value);
        }
        return value;
    }

    /**
     * @return java.lang.String[]
     * @Author liumingyu
     * @Description //TODO 重写getParameterValues
     * @Date 2020/1/10 2:32 下午
     * @Param [name]
     **/
    @Override
    public String[] getParameterValues(String name) {
        String[] parameterValues = super.getParameterValues(name);
        if (parameterValues == null) {
            return null;
        }
        for (int i = 0; i < parameterValues.length; i++) {
            String value = parameterValues[i];
            //调用Apache的工具类:StringEscapeUtils.escapeHtml4
            parameterValues[i] = StringEscapeUtils.escapeHtml4(value);
        }
        return parameterValues;
    }

    @Override
    public String getHeader(String name) {
        return StringEscapeUtils.escapeHtml4(super.getHeader(name));
    }

    @Override
    public String getQueryString() {
        return StringEscapeUtils.escapeHtml4(super.getQueryString());
    }

    /**
     * @return javax.servlet.ServletInputStream
     * @Author liumingyu
     * @Description //TODO 过滤JSON数据中的XSS攻击
     * @Date 2020/1/10 4:58 下午
     * @Param []
     **/
    @Override
    public ServletInputStream getInputStream() throws IOException {
        //调用方法将流数据return为String
        String str = getRequestBody(super.getInputStream());
        //如果str为"",则返回0
        if ("".equals(str)) {
            return new ServletInputStream() {
                @Override
                public int read() throws IOException {
                    return 0;
                }

                @Override
                public boolean isFinished() {
                    return false;
                }

                @Override
                public boolean isReady() {
                    return false;
                }

                @Override
                public void setReadListener(ReadListener readListener) {

                }
            };
        }
        //将数据存放至map
        Map<String, Object> map = JSON.parseObject(str, Map.class);
        //声明个存放过滤后数据的hashMap
        Map<String, Object> resultMap = new HashMap<>(map.size());
        //开始遍历数据
        for (String key : map.keySet()) {
            Object val = map.get(key);
            //如果key=富文本字段名,就不去过滤
            if ("content".equals(key)) {
                //不过滤
                resultMap.put(key, val);
            } else {
                //不为富文本字段才会过滤
                if (map.get(key) instanceof String) {
                    //通过escapeHtml4去过滤
                    resultMap.put(key, StringEscapeUtils.escapeHtml4(val.toString()));
                } else {
                    //不过滤
                    resultMap.put(key, val);
                }
            }
        }
        str = JSON.toJSONString(resultMap);
        final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(str.getBytes());
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return byteArrayInputStream.read();
            }

            @Override
            public boolean isFinished() {
                return false;
            }

            @Override
            public boolean isReady() {
                return false;
            }

            @Override
            public void setReadListener(ReadListener readListener) {

            }

        };
    }

    /**
     * @return java.lang.String
     * @Author liumingyu
     * @Description //TODO 获取JSON数据
     * @Date 2020/1/10 4:58 下午
     * @Param [stream]
     **/
    private String getRequestBody(InputStream stream) {
        String line = "";
        StringBuilder body = new StringBuilder();
        int counter = 0;
        // 读取POST提交的数据内容
        BufferedReader reader = new BufferedReader(new InputStreamReader(stream, Charset.forName("UTF-8")));
        try {
            while ((line = reader.readLine()) != null) {
                //拼接读取到的数据
                body.append(line);
                counter++;
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        if (body == null) {
            return "";
        }
        //最后返回数据
        return body.toString();
    }

}

2.看到这里你就已经完成了一半了加油!接下来的事情很简单写个过滤器就over 。

package com.hrt.zxxc.fxspg.xss;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Primary;
import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder;
import org.springframework.stereotype.Component;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;

/**
 * @program: fxspg
 * @description: XSS过滤器
 * @author: liumingyu
 * @date: 2020-01-10 14:36
 **/

@WebFilter
@Component
public class XssFilter implements Filter {

    /**
     * @return void
     * @Author liumingyu
     * @Description //TODO 重写init
     * @Date 2020/1/10 2:38 下午
     * @Param [filterConfig]
     **/
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {

    }

    /**
     * @return void
     * @Author liumingyu
     * @Description //TODO 重写doFilter,将请求进行xss过滤
     * @Date 2020/1/10 2:41 下午
     * @Param [servletRequest, servletResponse, filterChain]
     **/
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse,
                         FilterChain filterChain) throws IOException, ServletException {
        //获取请求数据
        HttpServletRequest req = (HttpServletRequest) servletRequest;
        //获取请求的url路径
        String path = ((HttpServletRequest) servletRequest).getServletPath();
        //声明要被忽略请求的数组
        String[] exclusionsUrls = {".js", ".gif", ".jpg", ".png", ".css", ".ico"};
        //遍历忽略的请求数组,若该接口url为忽略的就调用原本的过滤器,不走xss过滤
        for (String str : exclusionsUrls) {
            if (path.contains(str)) {
                filterChain.doFilter(servletRequest, servletResponse);
                return;
            }
        }
        //将请求放入XSS请求包装器中,返回过滤后的值
        XssAndSqlHttpServletRequestWrapper xssRequestWrapper = new XssAndSqlHttpServletRequestWrapper(req);
        filterChain.doFilter(xssRequestWrapper, servletResponse);
    }

    @Override
    public void destroy() {

    }

}

注:上面的注释我都写得比较全了,不过多解释,so easy是不是!

xss

相关推荐