diff --git a/server/directive.go b/server/directive.go new file mode 100644 index 0000000..c5d5538 --- /dev/null +++ b/server/directive.go @@ -0,0 +1,22 @@ +package server + +import "net/http" + +type Directive func(args ...string) Middleware + +var Add_Header Directive = func(args ...string) Middleware { + return func(w http.ResponseWriter, r *http.Request, next func()) { + w.Header().Add(args[0], args[1]) + } +} + +var Set_Header Directive = func(args ...string) Middleware { + return func(w http.ResponseWriter, r *http.Request, next func()) { + w.Header().Set(args[0], args[1]) + } +} + +var DirectiveMap = map[string]Directive{ + "Set-Header": Set_Header, + "Add-Header": Add_Header, +} diff --git a/server/proxyupdater.go b/server/proxyupdater.go new file mode 100644 index 0000000..d3ccaf6 --- /dev/null +++ b/server/proxyupdater.go @@ -0,0 +1,62 @@ +package server + +import ( + "net/http" + "strings" + + "github.com/samber/lo" +) + +type ProxyRequestUpdater func(arg ...string) func(r *http.Request) + +var ProxyRequestUpdateMap = map[string]ProxyRequestUpdater{ + "Host": func(arg ...string) func(r *http.Request) { + return func(r *http.Request) { + r.Host = arg[0] + } + }, + "Path": func(arg ...string) func(r *http.Request) { + replace := arg[0] + with := arg[1] + return func(r *http.Request) { + r.URL.Path = r.URL.Path[len(replace):] + r.URL.Path = with + r.URL.Path + } + }, + "RemoveCookie": func(arg ...string) func(r *http.Request) { + cookie := arg[0] + return func(r *http.Request) { + cookies := r.Cookies() + _, index, ok := lo.FindIndexOf(cookies, func(v *http.Cookie) bool { + return v.Name == cookie + }) + if ok { + r.Header.Del("Cookie") + if len(cookies) == 1 { + return + } + cookies = append(cookies[:index], cookies[index+1:]...) + for _, cookie := range cookies { + r.AddCookie(cookie) + } + } + + } + }, +} + +func GetUpdaterFn(directive string) func(r *http.Request) { + strs := strings.Split(directive, " ") + if len(strs) > 1 { + updater, ok := ProxyRequestUpdateMap[strs[0]] + if ok { + return updater(strs[1:]...) + } + } else { + updater, ok := ProxyRequestUpdateMap[directive] + if ok { + return updater() + } + } + return nil +} diff --git a/server/server.go b/server/server.go index cfc8ef5..2efcad0 100644 --- a/server/server.go +++ b/server/server.go @@ -187,3 +187,44 @@ func NewRestMux(path string) *RestMux { } return ret } + +type ServerMux struct { + http.Handler + directiveHandlers *MiddlewareLink + handlers map[string]http.Handler + paths []string +} + +func (s *ServerMux) Handle(pattern string, handler http.Handler) { + if s.handlers == nil { + s.handlers = make(map[string]http.Handler) + } + s.handlers[pattern] = handler + s.paths = append(s.paths, pattern) + // 自定义比较函数排序s.paths + sort.Slice(s.paths, func(i, j int) bool { + return len(s.paths[i]) > len(s.paths[j]) || s.paths[i] > s.paths[j] + }) +} + +func (s *ServerMux) AddDirective(directiveStr string) { + //TODO: 根据字符串内容生成一个中间件链,等directive实现再来补充逻辑 + strs := strings.Split(directiveStr, " ") + directiveName := strs[0] + params := strs[1:] + directive, ok := DirectiveMap[directiveName] + if ok { + s.directiveHandlers.Add(directive(params...)) + } +} + +func (s *ServerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + for _, p := range s.paths { + if strings.HasPrefix(r.URL.Path, p) { + s.directiveHandlers.ServeHTTP(w, r) + s.handlers[p].ServeHTTP(w, r) + return + } + } + http.NotFound(w, r) +}