add parse middleware
This commit is contained in:
parent
e22b39513e
commit
d6752cd324
|
@ -1,7 +1,9 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"git.pyer.club/kingecg/gohttpd/model"
|
"git.pyer.club/kingecg/gohttpd/model"
|
||||||
)
|
)
|
||||||
|
@ -24,9 +26,37 @@ func BasicAuth(w http.ResponseWriter, r *http.Request, next func()) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Parse(w http.ResponseWriter, r *http.Request, next func()) {
|
func Parse[T any](w http.ResponseWriter, r *http.Request, next func()) {
|
||||||
|
|
||||||
if r.Method == "POST" || r.Method == "PUT" {
|
if r.Method == "POST" || r.Method == "PUT" {
|
||||||
r.ParseForm()
|
//判断r的content-type是否是application/x-www-form-urlencoded
|
||||||
|
if r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" {
|
||||||
|
r.ParseForm()
|
||||||
|
next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Header.Get("Content-Type") == "multipart/form-data" {
|
||||||
|
r.ParseMultipartForm(r.ContentLength)
|
||||||
|
next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 判断r的content-type是否是application/json
|
||||||
|
contentType := r.Header.Get("Content-Type")
|
||||||
|
if strings.Contains(contentType, "application/json") {
|
||||||
|
var t T
|
||||||
|
// 读取json数据
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&t); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx := r.Context()
|
||||||
|
m := ctx.Value(RequestCtxKey("data")).(map[string]interface{})
|
||||||
|
if m != nil {
|
||||||
|
m["data"] = t
|
||||||
|
}
|
||||||
|
next()
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
next()
|
next()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RequestCtxKey string
|
||||||
|
|
||||||
// 可以嵌套的Rest http server mux
|
// 可以嵌套的Rest http server mux
|
||||||
type RestMux struct {
|
type RestMux struct {
|
||||||
Path string
|
Path string
|
||||||
|
@ -20,10 +23,14 @@ func (mux *RestMux) Use(m Middleware) {
|
||||||
}
|
}
|
||||||
func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
canContinue := false
|
canContinue := false
|
||||||
|
c := r.Context()
|
||||||
|
data := map[string]interface{}{}
|
||||||
|
cn := context.WithValue(c, RequestCtxKey("data"), data)
|
||||||
|
newRequest := r.WithContext(cn)
|
||||||
if len(mux.middlewares) > 0 {
|
if len(mux.middlewares) > 0 {
|
||||||
for _, m := range mux.middlewares {
|
for _, m := range mux.middlewares {
|
||||||
canContinue = false
|
canContinue = false
|
||||||
m(w, r, func() { canContinue = true })
|
m(w, newRequest, func() { canContinue = true })
|
||||||
if !canContinue {
|
if !canContinue {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -31,18 +38,18 @@ func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
_, has := lo.Find(mux.rmuxPaths, func(s string) bool {
|
_, has := lo.Find(mux.rmuxPaths, func(s string) bool {
|
||||||
return strings.HasPrefix(r.URL.Path, s)
|
return strings.HasPrefix(newRequest.URL.Path, s)
|
||||||
})
|
})
|
||||||
if has {
|
if has {
|
||||||
mux.imux.ServeHTTP(w, r)
|
mux.imux.ServeHTTP(w, newRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
r.URL.Path = "/" + strings.ToLower(r.Method) + r.URL.Path
|
newRequest.URL.Path = "/" + strings.ToLower(newRequest.Method) + newRequest.URL.Path
|
||||||
r.RequestURI = "/" + strings.ToLower(r.Method) + r.RequestURI
|
newRequest.RequestURI = "/" + strings.ToLower(newRequest.Method) + newRequest.RequestURI
|
||||||
h, _ := mux.imux.Handler(r)
|
h, _ := mux.imux.Handler(newRequest)
|
||||||
|
|
||||||
h.ServeHTTP(w, r)
|
h.ServeHTTP(w, newRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mux *RestMux) HandleFunc(method string, path string, f func(http.ResponseWriter, *http.Request)) {
|
func (mux *RestMux) HandleFunc(method string, path string, f func(http.ResponseWriter, *http.Request)) {
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UrlParamMatcher struct {
|
||||||
|
Params []string
|
||||||
|
Reg *regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析URL函数
|
||||||
|
// 参数:
|
||||||
|
//
|
||||||
|
// u:待解析的URL字符串
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
//
|
||||||
|
// 解析后的UrlParamMatcher结构体
|
||||||
|
func ParseUrl(u string) UrlParamMatcher {
|
||||||
|
ret := UrlParamMatcher{}
|
||||||
|
|
||||||
|
uo, _ := url.Parse(u)
|
||||||
|
|
||||||
|
// 判断路径是否非空且非根路径
|
||||||
|
if uo.Path != "" && uo.Path != "/" {
|
||||||
|
// 将路径按斜杠分割成切片
|
||||||
|
us := strings.Split(uo.Path, "/")
|
||||||
|
for index, v := range us {
|
||||||
|
// 判断是否以冒号开头
|
||||||
|
if strings.HasPrefix(v, ":") {
|
||||||
|
// 去除冒号前缀
|
||||||
|
param := strings.TrimPrefix(v, ":")
|
||||||
|
// 将参数添加到Params切片中
|
||||||
|
ret.Params = append(ret.Params, param)
|
||||||
|
// 将参数名作为正则表达式的命名捕获组
|
||||||
|
us[index] = "(?P<" + param + ">.*)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 如果存在参数,则将路径编译为正则表达式
|
||||||
|
if len(ret.Params) > 0 {
|
||||||
|
ret.Reg, _ = regexp.Compile("^" + strings.Join(us, "/") + "$")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func MatchUrlParam(u string, matcher *UrlParamMatcher) (map[string]string, bool) {
|
||||||
|
if matcher.Reg != nil {
|
||||||
|
uo, _ := url.Parse(u)
|
||||||
|
if uo.Path == "" || uo.Path == "/" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
matches := matcher.Reg.FindStringSubmatch(uo.Path)
|
||||||
|
if len(matches) > 0 {
|
||||||
|
params := make(map[string]string)
|
||||||
|
for i, name := range matcher.Params {
|
||||||
|
params[name] = matches[i+1]
|
||||||
|
}
|
||||||
|
return params, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
|
@ -0,0 +1,108 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseUrl(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
output UrlParamMatcher
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty Path",
|
||||||
|
input: "http://example.com",
|
||||||
|
output: UrlParamMatcher{
|
||||||
|
Params: []string{},
|
||||||
|
Reg: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Path with No Parameters",
|
||||||
|
input: "http://example.com/path/to/file",
|
||||||
|
output: UrlParamMatcher{
|
||||||
|
Params: []string{},
|
||||||
|
Reg: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Path with Parameters",
|
||||||
|
input: "http://example.com/:param1/:param2/:file",
|
||||||
|
output: UrlParamMatcher{
|
||||||
|
Params: []string{"param1", "param2", "file"},
|
||||||
|
Reg: regexp.MustCompile("^/(?P<param1>.*)/(?P<param2>.*)/(?P<file>.*)$"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
result := ParseUrl(test.input)
|
||||||
|
if !UrlMatcherEqual(result, test.output) {
|
||||||
|
t.Errorf("Expected %v, but got %v", test.output, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UrlMatcherEqual(a, b UrlParamMatcher) bool {
|
||||||
|
|
||||||
|
if strings.Join(a.Params, ";") != strings.Join(b.Params, ";") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.Reg == nil || b.Reg == nil {
|
||||||
|
return a.Reg == nil && b.Reg == nil
|
||||||
|
}
|
||||||
|
if a.Reg.String() != b.Reg.String() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchUrlParam(t *testing.T) {
|
||||||
|
matcher := &UrlParamMatcher{
|
||||||
|
Reg: nil,
|
||||||
|
Params: []string{"param1", "param2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
u1 := "http://example.com/path"
|
||||||
|
expected1 := map[string]string(nil)
|
||||||
|
result1, found1 := MatchUrlParam(u1, matcher)
|
||||||
|
if len(result1) != len(expected1) || !allMatch(result1, expected1) {
|
||||||
|
t.Errorf("Test case 1 failed. Expected %v, but got %v", expected1, result1)
|
||||||
|
}
|
||||||
|
if found1 {
|
||||||
|
t.Errorf("Test case 1 failed. Expected not to find a match.")
|
||||||
|
}
|
||||||
|
|
||||||
|
u4 := "http://example.com/param1/param2"
|
||||||
|
matcher4 := &UrlParamMatcher{
|
||||||
|
Reg: regexp.MustCompile("^/(?P<param1>.*)/(?P<param2>.*)$"),
|
||||||
|
Params: []string{"param1", "param2"},
|
||||||
|
}
|
||||||
|
expected4 := map[string]string{"param1": "param1", "param2": "param2"}
|
||||||
|
result4, found4 := MatchUrlParam(u4, matcher4)
|
||||||
|
if len(result4) != len(expected4) || !allMatch(result4, expected4) {
|
||||||
|
t.Errorf("Test case 4 failed. Expected %v, but got %v", expected4, result4)
|
||||||
|
}
|
||||||
|
if !found4 {
|
||||||
|
t.Errorf("Test case 4 failed. Expected not to find a match.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func allMatch(m1, m2 map[string]string) bool {
|
||||||
|
if len(m1) != len(m2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for k, v1 := range m1 {
|
||||||
|
v2, ok := m2[k]
|
||||||
|
if !ok || v1 != v2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
Loading…
Reference in New Issue