上两篇博文《手写spring核心原理Version1》和《手写spring核心原理Version2》分别介绍了如何完成自动注入，以及如何用设计模式进行重构。接下来这篇将仿照SpringMVC，对参数列表以及methodMapping进行重构。

重构MyDispatcherServlet

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.*;
import java.util.Map.Entry;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

public class MyDispatcherServlet extends HttpServlet{

	private static final String LOCATION = "contextConfigLocation";
	public static final String SCAN_PACKAGE = "scanPackage";

	private Properties properties = new Properties();

	private List<String> classNames = new ArrayList<>();

	private Map<String,Object> iocContainer = new HashMap<>();

	private List<Handler> handlerMapping = new ArrayList<>();

	public MyDispatcherServlet(){ super(); }
	

	public void init(ServletConfig config) {

		//1、加载配置文件
		loadConfigurations(config.getInitParameter(LOCATION));

		//2、扫描所有相关的类
		scanPackages(properties.getProperty(SCAN_PACKAGE));
		
		//3、初始化所有相关类的实例,并保存到IOC容器中
		initInstances();
		
		//4、依赖注入
		autowireInstance();

		//5、构造HandlerMapping
		initHandlerMapping();
	}
	
	protected void doGet(HttpServletRequest req, HttpServletResponse resp)
			throws ServletException, IOException {
		this.doPost(req, resp);
	}

	protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
		try{
			doDispatch(req,resp);
			
		}catch(Exception e){
			resp.getWriter().write("500 Exception,Details:\r\n" + Arrays.toString(e.getStackTrace()).replaceAll("\\[|\\]", "").replaceAll(",\\s", "\r\n"));
		}
	}
	
	private void scanPackages(String packageName){
		URL url = this.getClass().getClassLoader().getResource("/" + packageName.replaceAll("\\.", "/"));

		Stream.of(new File(url.getFile()).listFiles()).forEach(file -> {
			if (file.isDirectory()){
				scanPackages(packageName + "." + file.getName());
			}else {
				classNames.add(packageName + "." + file.getName().replace(".class", "").trim());
			}

		});
	}

	private void loadConfigurations(String location){
		try (InputStream input = this.getClass().getClassLoader().getResourceAsStream(location)){
			properties.load(input);
		}catch(Exception e){
			e.printStackTrace();
		}
	}

	private void initInstances(){
		if(classNames.size() == 0){
			return;
		}
		
		try{
			for (String className : classNames) {
				Class<?> clazz = Class.forName(className);
				String beanName;

				if(clazz.isAnnotationPresent(GPController.class)){
					beanName = getBeanName(clazz,clazz.getAnnotation(GPController.class).value());
				}else if(clazz.isAnnotationPresent(GPService.class)){
					beanName = getBeanName(clazz,clazz.getAnnotation(GPService.class).value());
				}else {
					continue;
				}

				final Object instance = clazz.newInstance();
				iocContainer.put(beanName, instance);
				Stream.of(clazz.getInterfaces()).forEach(i -> iocContainer.put(i.getName(),instance));
			}
		}catch(Exception e){
			e.printStackTrace();
		}
		
	}

	private void autowireInstance(){
		if(iocContainer.isEmpty()){
			return;
		}
		
		for (Entry<String, Object> entry : iocContainer.entrySet()) {
			Field [] fields = entry.getValue().getClass().getDeclaredFields();
			for (Field field : fields) {
				if(!field.isAnnotationPresent(GPAutowired.class)){
					continue;
				}
				
				String beanName = field.getAnnotation(GPAutowired.class).value().trim();
				if("".equals(beanName)){
					beanName = field.getType().getName();
				}

				field.setAccessible(true);
				try {
					field.set(entry.getValue(), iocContainer.get(beanName));
				} catch (Exception e) {
					e.printStackTrace();
					continue ;
				}
			}
		}
	}
	
	private void initHandlerMapping(){
		if(iocContainer.isEmpty()){
			return;
		}
		
		for (Entry<String, Object> entry : iocContainer.entrySet()) {
			Class<?> clazz = entry.getValue().getClass();
			if(!clazz.isAnnotationPresent(GPController.class)){
				continue;
			}
			
			String baseUrl = "";
			if(clazz.isAnnotationPresent(GPRequestMapping.class)){
				baseUrl = clazz.getAnnotation(GPRequestMapping.class).value();
			}
			
			Method [] methods = clazz.getMethods();
			for (Method method : methods) {
				if(!method.isAnnotationPresent(GPRequestMapping.class)){
					continue;
				}
				
				String regex =  baseUrl + method.getAnnotation(GPRequestMapping.class).value();
				Pattern pattern = Pattern.compile(regex);
				handlerMapping.add(new Handler(pattern,entry.getValue(),method));
			}
		}
		
	}
	
	private String getBeanName(Class<?> clazz,String name){
		if(!"".equals(name.trim())){
			return name;
		}

		char [] chars = clazz.getSimpleName().toCharArray();
		chars[0] += 32;
		return String.valueOf(chars);
	}
	
	private void doDispatch(HttpServletRequest req,HttpServletResponse resp) throws Exception{
		try{
			Handler handler = getHandler(req);
			
			if(handler == null){
				resp.getWriter().write("404 Not Found");
				return;
			}
			
			Class<?> [] paramTypes = handler.method.getParameterTypes();
			//保存所有需要自动赋值的参数值
			Object [] paramValues = new Object[paramTypes.length];
			
			
			Map<String,String[]> params = req.getParameterMap();
			for (Entry<String, String[]> param : params.entrySet()) {
				String value = Arrays.toString(param.getValue()).replaceAll("[\\[\\]]", "").replaceAll(",\\s", ",");
				
				if(!handler.paramIndexMapping.containsKey(param.getKey())){
					continue;
				}

				int index = handler.paramIndexMapping.get(param.getKey());
				paramValues[index] = convert(paramTypes[index],value);
			}
			
			
			//设置方法中的request和response对象
			int reqIndex = handler.paramIndexMapping.get(HttpServletRequest.class.getName());
			paramValues[reqIndex] = req;

			int respIndex = handler.paramIndexMapping.get(HttpServletResponse.class.getName());
			paramValues[respIndex] = resp;

			handler.method.invoke(handler.controller, paramValues);
			
		}catch(Exception e){
			throw e;
		}
	}
	
	private Handler getHandler(HttpServletRequest req) {
		if(handlerMapping.isEmpty()){
			return null;
		}
		
		String url = req.getRequestURI().replace(req.getContextPath(), "");

		for (Handler handler : handlerMapping) {
			try{
				Matcher matcher = handler.pattern.matcher(url);
				if(!matcher.matches()){
					continue;
				}
				
				return handler;
			}catch(Exception e){
				throw e;
			}
		}
		return null;
	}

	//url传过来的参数都是String类型的,HTTP是基于字符串协议
	//只需要把String转换为任意类型就好
	private Object convert(Class<?> type,String value){
		if(Integer.class == type){
			return Integer.valueOf(value);
		}
		//如果还有double或者其他类型,继续加if
		//这时候,我们应该想到策略模式了
		//在这里暂时不实现
		return value;
	}

	/**
	 * Handler记录Controller中的RequestMapping和Method的对应关系
	 */
	private class Handler{
		protected Object controller;
		protected Method method;
		protected Pattern pattern;
		protected Map<String,Integer> paramIndexMapping;	//参数顺序
		
		protected Handler(Pattern pattern,Object controller,Method method){
			this.controller = controller;
			this.method = method;
			this.pattern = pattern;
			
			paramIndexMapping = new HashMap<>();
			putParamIndexMapping(method);
		}
		
		private void putParamIndexMapping(Method method){
			//提取方法中加了注解的参数
			Annotation [] [] pa = method.getParameterAnnotations();
			for (int i = 0; i < pa.length ; i ++) {
				for(Annotation a : pa[i]){
					if(a instanceof GPRequestParam){
						String paramName = ((GPRequestParam) a).value();
						if(!"".equals(paramName.trim())){
							paramIndexMapping.put(paramName, i);
						}
					}
				}
			}
			
			//提取方法中的request和response参数
			Class<?> [] paramsTypes = method.getParameterTypes();
			for (int i = 0; i < paramsTypes.length ; i ++) {
				Class<?> type = paramsTypes[i];
				if(type == HttpServletRequest.class ||
				   type == HttpServletResponse.class){
					paramIndexMapping.put(type.getName(),i);
				}
			}
		}
	}
}
  • 该demo源码参考自书籍《Spring 5核心原理与30个类手写实战》
Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐