add parse middleware
This commit is contained in:
parent
e22b39513e
commit
d6752cd324
|
@ -1,7 +1,9 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.pyer.club/kingecg/gohttpd/model"
|
||||
)
|
||||
|
@ -24,9 +26,37 @@ func BasicAuth(w http.ResponseWriter, r *http.Request, next func()) {
|
|||
}
|
||||
}
|
||||
|
||||
func Parse(w http.ResponseWriter, r *http.Request, next func()) {
|
||||
func Parse[T any](w http.ResponseWriter, r *http.Request, next func()) {
|
||||
|
||||
if r.Method == "POST" || r.Method == "PUT" {
|
||||
r.ParseForm()
|
||||
//判断r的content-type是否是application/x-www-form-urlencoded
|
||||
if r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" {
|
||||
r.ParseForm()
|
||||
next()
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Content-Type") == "multipart/form-data" {
|
||||
r.ParseMultipartForm(r.ContentLength)
|
||||
next()
|
||||
return
|
||||
}
|
||||
// 判断r的content-type是否是application/json
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
var t T
|
||||
// 读取json数据
|
||||
if err := json.NewDecoder(r.Body).Decode(&t); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
m := ctx.Value(RequestCtxKey("data")).(map[string]interface{})
|
||||
if m != nil {
|
||||
m["data"] = t
|
||||
}
|
||||
next()
|
||||
return
|
||||
}
|
||||
}
|
||||
next()
|
||||
}
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
type RequestCtxKey string
|
||||
|
||||
// 可以嵌套的Rest http server mux
|
||||
type RestMux struct {
|
||||
Path string
|
||||
|
@ -20,10 +23,14 @@ func (mux *RestMux) Use(m Middleware) {
|
|||
}
|
||||
func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
canContinue := false
|
||||
c := r.Context()
|
||||
data := map[string]interface{}{}
|
||||
cn := context.WithValue(c, RequestCtxKey("data"), data)
|
||||
newRequest := r.WithContext(cn)
|
||||
if len(mux.middlewares) > 0 {
|
||||
for _, m := range mux.middlewares {
|
||||
canContinue = false
|
||||
m(w, r, func() { canContinue = true })
|
||||
m(w, newRequest, func() { canContinue = true })
|
||||
if !canContinue {
|
||||
return
|
||||
}
|
||||
|
@ -31,18 +38,18 @@ func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
_, has := lo.Find(mux.rmuxPaths, func(s string) bool {
|
||||
return strings.HasPrefix(r.URL.Path, s)
|
||||
return strings.HasPrefix(newRequest.URL.Path, s)
|
||||
})
|
||||
if has {
|
||||
mux.imux.ServeHTTP(w, r)
|
||||
mux.imux.ServeHTTP(w, newRequest)
|
||||
return
|
||||
}
|
||||
|
||||
r.URL.Path = "/" + strings.ToLower(r.Method) + r.URL.Path
|
||||
r.RequestURI = "/" + strings.ToLower(r.Method) + r.RequestURI
|
||||
h, _ := mux.imux.Handler(r)
|
||||
newRequest.URL.Path = "/" + strings.ToLower(newRequest.Method) + newRequest.URL.Path
|
||||
newRequest.RequestURI = "/" + strings.ToLower(newRequest.Method) + newRequest.RequestURI
|
||||
h, _ := mux.imux.Handler(newRequest)
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
h.ServeHTTP(w, newRequest)
|
||||
}
|
||||
|
||||
func (mux *RestMux) HandleFunc(method string, path string, f func(http.ResponseWriter, *http.Request)) {
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type UrlParamMatcher struct {
|
||||
Params []string
|
||||
Reg *regexp.Regexp
|
||||
}
|
||||
|
||||
// 解析URL函数
|
||||
// 参数:
|
||||
//
|
||||
// u:待解析的URL字符串
|
||||
//
|
||||
// 返回值:
|
||||
//
|
||||
// 解析后的UrlParamMatcher结构体
|
||||
func ParseUrl(u string) UrlParamMatcher {
|
||||
ret := UrlParamMatcher{}
|
||||
|
||||
uo, _ := url.Parse(u)
|
||||
|
||||
// 判断路径是否非空且非根路径
|
||||
if uo.Path != "" && uo.Path != "/" {
|
||||
// 将路径按斜杠分割成切片
|
||||
us := strings.Split(uo.Path, "/")
|
||||
for index, v := range us {
|
||||
// 判断是否以冒号开头
|
||||
if strings.HasPrefix(v, ":") {
|
||||
// 去除冒号前缀
|
||||
param := strings.TrimPrefix(v, ":")
|
||||
// 将参数添加到Params切片中
|
||||
ret.Params = append(ret.Params, param)
|
||||
// 将参数名作为正则表达式的命名捕获组
|
||||
us[index] = "(?P<" + param + ">.*)"
|
||||
}
|
||||
}
|
||||
// 如果存在参数,则将路径编译为正则表达式
|
||||
if len(ret.Params) > 0 {
|
||||
ret.Reg, _ = regexp.Compile("^" + strings.Join(us, "/") + "$")
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func MatchUrlParam(u string, matcher *UrlParamMatcher) (map[string]string, bool) {
|
||||
if matcher.Reg != nil {
|
||||
uo, _ := url.Parse(u)
|
||||
if uo.Path == "" || uo.Path == "/" {
|
||||
return nil, false
|
||||
}
|
||||
matches := matcher.Reg.FindStringSubmatch(uo.Path)
|
||||
if len(matches) > 0 {
|
||||
params := make(map[string]string)
|
||||
for i, name := range matcher.Params {
|
||||
params[name] = matches[i+1]
|
||||
}
|
||||
return params, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseUrl(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
output UrlParamMatcher
|
||||
}{
|
||||
{
|
||||
name: "Empty Path",
|
||||
input: "http://example.com",
|
||||
output: UrlParamMatcher{
|
||||
Params: []string{},
|
||||
Reg: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Path with No Parameters",
|
||||
input: "http://example.com/path/to/file",
|
||||
output: UrlParamMatcher{
|
||||
Params: []string{},
|
||||
Reg: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Path with Parameters",
|
||||
input: "http://example.com/:param1/:param2/:file",
|
||||
output: UrlParamMatcher{
|
||||
Params: []string{"param1", "param2", "file"},
|
||||
Reg: regexp.MustCompile("^/(?P<param1>.*)/(?P<param2>.*)/(?P<file>.*)$"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := ParseUrl(test.input)
|
||||
if !UrlMatcherEqual(result, test.output) {
|
||||
t.Errorf("Expected %v, but got %v", test.output, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func UrlMatcherEqual(a, b UrlParamMatcher) bool {
|
||||
|
||||
if strings.Join(a.Params, ";") != strings.Join(b.Params, ";") {
|
||||
return false
|
||||
}
|
||||
|
||||
if a.Reg == nil || b.Reg == nil {
|
||||
return a.Reg == nil && b.Reg == nil
|
||||
}
|
||||
if a.Reg.String() != b.Reg.String() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestMatchUrlParam(t *testing.T) {
|
||||
matcher := &UrlParamMatcher{
|
||||
Reg: nil,
|
||||
Params: []string{"param1", "param2"},
|
||||
}
|
||||
|
||||
u1 := "http://example.com/path"
|
||||
expected1 := map[string]string(nil)
|
||||
result1, found1 := MatchUrlParam(u1, matcher)
|
||||
if len(result1) != len(expected1) || !allMatch(result1, expected1) {
|
||||
t.Errorf("Test case 1 failed. Expected %v, but got %v", expected1, result1)
|
||||
}
|
||||
if found1 {
|
||||
t.Errorf("Test case 1 failed. Expected not to find a match.")
|
||||
}
|
||||
|
||||
u4 := "http://example.com/param1/param2"
|
||||
matcher4 := &UrlParamMatcher{
|
||||
Reg: regexp.MustCompile("^/(?P<param1>.*)/(?P<param2>.*)$"),
|
||||
Params: []string{"param1", "param2"},
|
||||
}
|
||||
expected4 := map[string]string{"param1": "param1", "param2": "param2"}
|
||||
result4, found4 := MatchUrlParam(u4, matcher4)
|
||||
if len(result4) != len(expected4) || !allMatch(result4, expected4) {
|
||||
t.Errorf("Test case 4 failed. Expected %v, but got %v", expected4, result4)
|
||||
}
|
||||
if !found4 {
|
||||
t.Errorf("Test case 4 failed. Expected not to find a match.")
|
||||
}
|
||||
}
|
||||
|
||||
func allMatch(m1, m2 map[string]string) bool {
|
||||
if len(m1) != len(m2) {
|
||||
return false
|
||||
}
|
||||
for k, v1 := range m1 {
|
||||
v2, ok := m2[k]
|
||||
if !ok || v1 != v2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
Loading…
Reference in New Issue