diff --git a/server/middleware.go b/server/middleware.go index f68ecb3..05bee28 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -4,6 +4,7 @@ import ( "container/list" "encoding/json" "net/http" + "reflect" "strings" "git.pyer.club/kingecg/gohttpd/model" @@ -15,26 +16,41 @@ type MiddlewareLink struct { *list.List } +func IsEqualMiddleware(a, b Middleware) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer() +} func (ml *MiddlewareLink) Add(m Middleware) { + if m == nil { + return + } + if ml.List.Len() == 0 { ml.PushBack(m) } else { + if IsEqualMiddleware(m, Done) { + return + } el := ml.Back() ml.InsertBefore(m, el) } } func (ml *MiddlewareLink) ServeHTTP(w http.ResponseWriter, r *http.Request) bool { - canContinue := false + canContinue := true next := func() { canContinue = true } for e := ml.Front(); e != nil && canContinue; e = e.Next() { + canContinue = false e.Value.(Middleware)(w, r, next) if !canContinue { break - } else { - canContinue = false } } return canContinue diff --git a/server/middleware_test.go b/server/middleware_test.go new file mode 100644 index 0000000..d59e43d --- /dev/null +++ b/server/middleware_test.go @@ -0,0 +1,86 @@ +package server + +import ( + "container/list" + "net/http" + "testing" +) + +func TestNewMiddlewareLink(t *testing.T) { + ml := NewMiddlewareLink() + + if ml == nil { + t.Errorf("NewMiddlewareLink() = nil, want %T", &MiddlewareLink{list.New()}) + } + if ml.Len() == 0 { + t.Errorf("NewMiddlewareLink() returned list with length %d, want %d", 0, 1) + } +} + +func TestMiddlewareLink_Add(t *testing.T) { + ml := NewMiddlewareLink() + element := Done + + if ml.Len() != 1 { + t.Errorf("Add() did not add element to the list, length = %d, want %d", ml.Len(), 1) + } + ml.Add(element) + if ml.Len() != 1 { + t.Error("Should not Add Done twice") + } + + t1 := ml.ServeHTTP(nil, nil) + + if !t1 { + t.Error("ServeHTTP() returned false, want true") + } + + B := func(w http.ResponseWriter, r *http.Request, next func()) { + + } + ml.Add(B) + if ml.Len() != 2 { + t.Error("Should has 2 middleware but got 1") + } + t2 := ml.ServeHTTP(nil, nil) + if t2 { + t.Error("ServeHTTP() returned true, want false") + } +} + +func TestNewRoute(t *testing.T) { + // 创建一个测试用的http.Handler + handleFn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // 调用被测试的函数 + route := NewRoute("GET", "/test", handleFn) + + // 验证生成的route是否符合预期 + if route.Method != "GET" { + t.Errorf("Expected method to be GET, got %s", route.Method) + } + if route.Path != "/test" { + t.Errorf("Expected path to be /test, got %s", route.Path) + } + if route.Handler == nil { + t.Error("Expected handler to be set") + } + if route.matcher == nil { + t.Error("Expected matcher to be set") + } + if route.middles == nil { + t.Error("Expected middles to be set") + } + + route2 := NewRoute("GET", "/test2/:id", handleFn) + + if route2.matcher.Params[0] != "id" { + t.Errorf("Expected param name to be id, got %s", route2.matcher.Params[0]) + } + + if route2.matcher.Reg == nil { + t.Error("Expected matcher.Reg to be set") + } +} diff --git a/server/server.go b/server/server.go index c0cd3f3..90d6f87 100644 --- a/server/server.go +++ b/server/server.go @@ -50,6 +50,14 @@ func (route *Route) Match(r *http.Request) bool { func (route *Route) Add(m Middleware) { route.middles.Add(m) } + +// NewRoute 返回一个新的Route实例 +// 参数: +// - method: 请求方法 +// - path: 请求路径 +// - handleFn: http.Handler处理函数 +// 返回值: +// - *Route: 一个指向Route的指针 func NewRoute(method string, path string, handleFn http.Handler) *Route { ret := &Route{ Method: method, @@ -57,7 +65,7 @@ func NewRoute(method string, path string, handleFn http.Handler) *Route { middles: NewMiddlewareLink(), } p := ParseUrl(path) - //使用handleFn构建handler + // 使用handleFn构建handler ret.Handler = handleFn ret.matcher = &p return ret