diff --git a/server/middleware.go b/server/middleware.go index 4e27cb0..2d08971 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -1,7 +1,9 @@ package server import ( + "encoding/json" "net/http" + "strings" "git.pyer.club/kingecg/gohttpd/model" ) @@ -24,9 +26,37 @@ func BasicAuth(w http.ResponseWriter, r *http.Request, next func()) { } } -func Parse(w http.ResponseWriter, r *http.Request, next func()) { +func Parse[T any](w http.ResponseWriter, r *http.Request, next func()) { + if r.Method == "POST" || r.Method == "PUT" { - r.ParseForm() + //判断r的content-type是否是application/x-www-form-urlencoded + if r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" { + r.ParseForm() + next() + return + } + if r.Header.Get("Content-Type") == "multipart/form-data" { + r.ParseMultipartForm(r.ContentLength) + next() + return + } + // 判断r的content-type是否是application/json + contentType := r.Header.Get("Content-Type") + if strings.Contains(contentType, "application/json") { + var t T + // 读取json数据 + if err := json.NewDecoder(r.Body).Decode(&t); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + ctx := r.Context() + m := ctx.Value(RequestCtxKey("data")).(map[string]interface{}) + if m != nil { + m["data"] = t + } + next() + return + } } next() } diff --git a/server/server.go b/server/server.go index 313292d..cff0285 100644 --- a/server/server.go +++ b/server/server.go @@ -1,12 +1,15 @@ package server import ( + "context" "net/http" "strings" "github.com/samber/lo" ) +type RequestCtxKey string + // 可以嵌套的Rest http server mux type RestMux struct { Path string @@ -20,10 +23,14 @@ func (mux *RestMux) Use(m Middleware) { } func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { canContinue := false + c := r.Context() + data := map[string]interface{}{} + cn := context.WithValue(c, RequestCtxKey("data"), data) + newRequest := r.WithContext(cn) if len(mux.middlewares) > 0 { for _, m := range mux.middlewares { canContinue = false - m(w, r, func() { canContinue = true }) + m(w, newRequest, func() { canContinue = true }) if !canContinue { return } @@ -31,18 +38,18 @@ func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { } _, has := lo.Find(mux.rmuxPaths, func(s string) bool { - return strings.HasPrefix(r.URL.Path, s) + return strings.HasPrefix(newRequest.URL.Path, s) }) if has { - mux.imux.ServeHTTP(w, r) + mux.imux.ServeHTTP(w, newRequest) return } - r.URL.Path = "/" + strings.ToLower(r.Method) + r.URL.Path - r.RequestURI = "/" + strings.ToLower(r.Method) + r.RequestURI - h, _ := mux.imux.Handler(r) + newRequest.URL.Path = "/" + strings.ToLower(newRequest.Method) + newRequest.URL.Path + newRequest.RequestURI = "/" + strings.ToLower(newRequest.Method) + newRequest.RequestURI + h, _ := mux.imux.Handler(newRequest) - h.ServeHTTP(w, r) + h.ServeHTTP(w, newRequest) } func (mux *RestMux) HandleFunc(method string, path string, f func(http.ResponseWriter, *http.Request)) { diff --git a/server/url.go b/server/url.go new file mode 100644 index 0000000..e7e4235 --- /dev/null +++ b/server/url.go @@ -0,0 +1,67 @@ +package server + +import ( + "net/url" + "regexp" + "strings" +) + +type UrlParamMatcher struct { + Params []string + Reg *regexp.Regexp +} + +// 解析URL函数 +// 参数: +// +// u:待解析的URL字符串 +// +// 返回值: +// +// 解析后的UrlParamMatcher结构体 +func ParseUrl(u string) UrlParamMatcher { + ret := UrlParamMatcher{} + + uo, _ := url.Parse(u) + + // 判断路径是否非空且非根路径 + if uo.Path != "" && uo.Path != "/" { + // 将路径按斜杠分割成切片 + us := strings.Split(uo.Path, "/") + for index, v := range us { + // 判断是否以冒号开头 + if strings.HasPrefix(v, ":") { + // 去除冒号前缀 + param := strings.TrimPrefix(v, ":") + // 将参数添加到Params切片中 + ret.Params = append(ret.Params, param) + // 将参数名作为正则表达式的命名捕获组 + us[index] = "(?P<" + param + ">.*)" + } + } + // 如果存在参数,则将路径编译为正则表达式 + if len(ret.Params) > 0 { + ret.Reg, _ = regexp.Compile("^" + strings.Join(us, "/") + "$") + } + } + + return ret +} + +func MatchUrlParam(u string, matcher *UrlParamMatcher) (map[string]string, bool) { + if matcher.Reg != nil { + uo, _ := url.Parse(u) + if uo.Path == "" || uo.Path == "/" { + return nil, false + } + matches := matcher.Reg.FindStringSubmatch(uo.Path) + if len(matches) > 0 { + params := make(map[string]string) + for i, name := range matcher.Params { + params[name] = matches[i+1] + } + return params, true + } + } + return nil, false +} diff --git a/server/url_test.go b/server/url_test.go new file mode 100644 index 0000000..2813e21 --- /dev/null +++ b/server/url_test.go @@ -0,0 +1,108 @@ +package server + +import ( + "regexp" + "strings" + "testing" +) + +func TestParseUrl(t *testing.T) { + tests := []struct { + name string + input string + output UrlParamMatcher + }{ + { + name: "Empty Path", + input: "http://example.com", + output: UrlParamMatcher{ + Params: []string{}, + Reg: nil, + }, + }, + { + name: "Path with No Parameters", + input: "http://example.com/path/to/file", + output: UrlParamMatcher{ + Params: []string{}, + Reg: nil, + }, + }, + { + name: "Path with Parameters", + input: "http://example.com/:param1/:param2/:file", + output: UrlParamMatcher{ + Params: []string{"param1", "param2", "file"}, + Reg: regexp.MustCompile("^/(?P.*)/(?P.*)/(?P.*)$"), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := ParseUrl(test.input) + if !UrlMatcherEqual(result, test.output) { + t.Errorf("Expected %v, but got %v", test.output, result) + } + }) + } +} + +func UrlMatcherEqual(a, b UrlParamMatcher) bool { + + if strings.Join(a.Params, ";") != strings.Join(b.Params, ";") { + return false + } + + if a.Reg == nil || b.Reg == nil { + return a.Reg == nil && b.Reg == nil + } + if a.Reg.String() != b.Reg.String() { + return false + } + return true +} + +func TestMatchUrlParam(t *testing.T) { + matcher := &UrlParamMatcher{ + Reg: nil, + Params: []string{"param1", "param2"}, + } + + u1 := "http://example.com/path" + expected1 := map[string]string(nil) + result1, found1 := MatchUrlParam(u1, matcher) + if len(result1) != len(expected1) || !allMatch(result1, expected1) { + t.Errorf("Test case 1 failed. Expected %v, but got %v", expected1, result1) + } + if found1 { + t.Errorf("Test case 1 failed. Expected not to find a match.") + } + + u4 := "http://example.com/param1/param2" + matcher4 := &UrlParamMatcher{ + Reg: regexp.MustCompile("^/(?P.*)/(?P.*)$"), + Params: []string{"param1", "param2"}, + } + expected4 := map[string]string{"param1": "param1", "param2": "param2"} + result4, found4 := MatchUrlParam(u4, matcher4) + if len(result4) != len(expected4) || !allMatch(result4, expected4) { + t.Errorf("Test case 4 failed. Expected %v, but got %v", expected4, result4) + } + if !found4 { + t.Errorf("Test case 4 failed. Expected not to find a match.") + } +} + +func allMatch(m1, m2 map[string]string) bool { + if len(m1) != len(m2) { + return false + } + for k, v1 := range m1 { + v2, ok := m2[k] + if !ok || v1 != v2 { + return false + } + } + return true +}