add parse middleware

This commit is contained in:
kingecg 2023-12-11 23:46:40 +08:00
parent e22b39513e
commit d6752cd324
4 changed files with 221 additions and 9 deletions

View File

@ -1,7 +1,9 @@
package server
import (
"encoding/json"
"net/http"
"strings"
"git.pyer.club/kingecg/gohttpd/model"
)
@ -24,9 +26,37 @@ func BasicAuth(w http.ResponseWriter, r *http.Request, next func()) {
}
}
func Parse(w http.ResponseWriter, r *http.Request, next func()) {
func Parse[T any](w http.ResponseWriter, r *http.Request, next func()) {
if r.Method == "POST" || r.Method == "PUT" {
r.ParseForm()
//判断r的content-type是否是application/x-www-form-urlencoded
if r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" {
r.ParseForm()
next()
return
}
if r.Header.Get("Content-Type") == "multipart/form-data" {
r.ParseMultipartForm(r.ContentLength)
next()
return
}
// 判断r的content-type是否是application/json
contentType := r.Header.Get("Content-Type")
if strings.Contains(contentType, "application/json") {
var t T
// 读取json数据
if err := json.NewDecoder(r.Body).Decode(&t); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
ctx := r.Context()
m := ctx.Value(RequestCtxKey("data")).(map[string]interface{})
if m != nil {
m["data"] = t
}
next()
return
}
}
next()
}

View File

@ -1,12 +1,15 @@
package server
import (
"context"
"net/http"
"strings"
"github.com/samber/lo"
)
type RequestCtxKey string
// 可以嵌套的Rest http server mux
type RestMux struct {
Path string
@ -20,10 +23,14 @@ func (mux *RestMux) Use(m Middleware) {
}
func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
canContinue := false
c := r.Context()
data := map[string]interface{}{}
cn := context.WithValue(c, RequestCtxKey("data"), data)
newRequest := r.WithContext(cn)
if len(mux.middlewares) > 0 {
for _, m := range mux.middlewares {
canContinue = false
m(w, r, func() { canContinue = true })
m(w, newRequest, func() { canContinue = true })
if !canContinue {
return
}
@ -31,18 +38,18 @@ func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
_, has := lo.Find(mux.rmuxPaths, func(s string) bool {
return strings.HasPrefix(r.URL.Path, s)
return strings.HasPrefix(newRequest.URL.Path, s)
})
if has {
mux.imux.ServeHTTP(w, r)
mux.imux.ServeHTTP(w, newRequest)
return
}
r.URL.Path = "/" + strings.ToLower(r.Method) + r.URL.Path
r.RequestURI = "/" + strings.ToLower(r.Method) + r.RequestURI
h, _ := mux.imux.Handler(r)
newRequest.URL.Path = "/" + strings.ToLower(newRequest.Method) + newRequest.URL.Path
newRequest.RequestURI = "/" + strings.ToLower(newRequest.Method) + newRequest.RequestURI
h, _ := mux.imux.Handler(newRequest)
h.ServeHTTP(w, r)
h.ServeHTTP(w, newRequest)
}
func (mux *RestMux) HandleFunc(method string, path string, f func(http.ResponseWriter, *http.Request)) {

67
server/url.go Normal file
View File

@ -0,0 +1,67 @@
package server
import (
"net/url"
"regexp"
"strings"
)
type UrlParamMatcher struct {
Params []string
Reg *regexp.Regexp
}
// 解析URL函数
// 参数:
//
// u待解析的URL字符串
//
// 返回值:
//
// 解析后的UrlParamMatcher结构体
func ParseUrl(u string) UrlParamMatcher {
ret := UrlParamMatcher{}
uo, _ := url.Parse(u)
// 判断路径是否非空且非根路径
if uo.Path != "" && uo.Path != "/" {
// 将路径按斜杠分割成切片
us := strings.Split(uo.Path, "/")
for index, v := range us {
// 判断是否以冒号开头
if strings.HasPrefix(v, ":") {
// 去除冒号前缀
param := strings.TrimPrefix(v, ":")
// 将参数添加到Params切片中
ret.Params = append(ret.Params, param)
// 将参数名作为正则表达式的命名捕获组
us[index] = "(?P<" + param + ">.*)"
}
}
// 如果存在参数,则将路径编译为正则表达式
if len(ret.Params) > 0 {
ret.Reg, _ = regexp.Compile("^" + strings.Join(us, "/") + "$")
}
}
return ret
}
func MatchUrlParam(u string, matcher *UrlParamMatcher) (map[string]string, bool) {
if matcher.Reg != nil {
uo, _ := url.Parse(u)
if uo.Path == "" || uo.Path == "/" {
return nil, false
}
matches := matcher.Reg.FindStringSubmatch(uo.Path)
if len(matches) > 0 {
params := make(map[string]string)
for i, name := range matcher.Params {
params[name] = matches[i+1]
}
return params, true
}
}
return nil, false
}

108
server/url_test.go Normal file
View File

@ -0,0 +1,108 @@
package server
import (
"regexp"
"strings"
"testing"
)
func TestParseUrl(t *testing.T) {
tests := []struct {
name string
input string
output UrlParamMatcher
}{
{
name: "Empty Path",
input: "http://example.com",
output: UrlParamMatcher{
Params: []string{},
Reg: nil,
},
},
{
name: "Path with No Parameters",
input: "http://example.com/path/to/file",
output: UrlParamMatcher{
Params: []string{},
Reg: nil,
},
},
{
name: "Path with Parameters",
input: "http://example.com/:param1/:param2/:file",
output: UrlParamMatcher{
Params: []string{"param1", "param2", "file"},
Reg: regexp.MustCompile("^/(?P<param1>.*)/(?P<param2>.*)/(?P<file>.*)$"),
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := ParseUrl(test.input)
if !UrlMatcherEqual(result, test.output) {
t.Errorf("Expected %v, but got %v", test.output, result)
}
})
}
}
func UrlMatcherEqual(a, b UrlParamMatcher) bool {
if strings.Join(a.Params, ";") != strings.Join(b.Params, ";") {
return false
}
if a.Reg == nil || b.Reg == nil {
return a.Reg == nil && b.Reg == nil
}
if a.Reg.String() != b.Reg.String() {
return false
}
return true
}
func TestMatchUrlParam(t *testing.T) {
matcher := &UrlParamMatcher{
Reg: nil,
Params: []string{"param1", "param2"},
}
u1 := "http://example.com/path"
expected1 := map[string]string(nil)
result1, found1 := MatchUrlParam(u1, matcher)
if len(result1) != len(expected1) || !allMatch(result1, expected1) {
t.Errorf("Test case 1 failed. Expected %v, but got %v", expected1, result1)
}
if found1 {
t.Errorf("Test case 1 failed. Expected not to find a match.")
}
u4 := "http://example.com/param1/param2"
matcher4 := &UrlParamMatcher{
Reg: regexp.MustCompile("^/(?P<param1>.*)/(?P<param2>.*)$"),
Params: []string{"param1", "param2"},
}
expected4 := map[string]string{"param1": "param1", "param2": "param2"}
result4, found4 := MatchUrlParam(u4, matcher4)
if len(result4) != len(expected4) || !allMatch(result4, expected4) {
t.Errorf("Test case 4 failed. Expected %v, but got %v", expected4, result4)
}
if !found4 {
t.Errorf("Test case 4 failed. Expected not to find a match.")
}
}
func allMatch(m1, m2 map[string]string) bool {
if len(m1) != len(m2) {
return false
}
for k, v1 := range m1 {
v2, ok := m2[k]
if !ok || v1 != v2 {
return false
}
}
return true
}