From 35af0ea0ad3c32e9e6056131873702b6fdbfdf7b Mon Sep 17 00:00:00 2001 From: root <1@root.com> Date: Wed, 9 Jul 2025 02:07:30 +0900 Subject: [PATCH] =?UTF-8?q?=E8=A7=A3=E5=86=B3request=E6=97=A0=E6=B3=95?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E7=BB=93=E6=9E=84=E4=BD=93=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?query=E5=92=8Cform=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- request_params.go | 222 ++++++++++++++++++++++++++++++++++ request_params_test.go | 264 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 486 insertions(+) create mode 100644 request_params.go create mode 100644 request_params_test.go diff --git a/request_params.go b/request_params.go new file mode 100644 index 0000000..943bf8f --- /dev/null +++ b/request_params.go @@ -0,0 +1,222 @@ +package req + +import ( + "fmt" + "net/url" + "reflect" + "strconv" + "strings" +) + +// 解决request无法使用结构体添加query和form的问题 +// 为什么要这样做?因为map是无序的,在某些情况下,参数顺序不一致,会导致请求失败和调试困难 + +// SetQueryParamsStruct 从结构体序列化为query参数,保持字段定义顺序 +// 支持的标签: `query:"name"` 或 `json:"name"` 或 `form:"name"` +// 支持 `query:"-"` 忽略字段 +// 支持 `query:"name,omitempty"` 忽略零值 +func (r *Request) SetQueryParamsStruct(params any) *Request { + if params == nil { + return r + } + + queryParams := r.marshalToUrlValues(params, "query") + if r.QueryParams == nil { + r.QueryParams = queryParams + } else { + // 合并到现有的查询参数中 + for key, values := range queryParams { + for _, value := range values { + r.QueryParams.Add(key, value) + } + } + } + return r +} + +// SetFormDataStruct 从结构体序列化为form数据,保持字段定义顺序 +// 支持的标签: `form:"name"` 或 `json:"name"` 或 `query:"name"` +// 支持 `form:"-"` 忽略字段 +// 支持 `form:"name,omitempty"` 忽略零值 +func (r *Request) SetFormDataStruct(params any) *Request { + if params == nil { + return r + } + + formData := r.marshalToUrlValues(params, "form") + if r.FormData == nil { + r.FormData = formData + } else { + // 合并到现有的表单数据中 + for key, values := range formData { + for _, value := range values { + r.FormData.Add(key, value) + } + } + } + return r +} + +// marshalToUrlValues 将结构体转换为url.Values,保持字段顺序 +func (r *Request) marshalToUrlValues(params any, primaryTag string) url.Values { + result := url.Values{} + + rv := reflect.ValueOf(params) + if rv.Kind() == reflect.Ptr { + if rv.IsNil() { + return result + } + rv = rv.Elem() + } + + if rv.Kind() != reflect.Struct { + r.appendError(fmt.Errorf("params must be a struct or pointer to struct, got %T", params)) + return result + } + + rt := rv.Type() + + // 按照字段定义顺序遍历 + for i := 0; i < rt.NumField(); i++ { + field := rt.Field(i) + fieldValue := rv.Field(i) + + // 跳过非导出字段 + if !field.IsExported() { + continue + } + + // 获取标签名和选项 + tagName, omitempty := r.getFieldTag(field, primaryTag) + if tagName == "-" { + continue + } + + // 只有在设置了omitempty且值为零值时才跳过 + if omitempty && r.isZeroValue(fieldValue) { + continue + } + + // 转换值为字符串 + values := r.convertToStringValues(fieldValue) + for _, value := range values { + result.Add(tagName, value) + } + } + + return result +} + +// getFieldTag 获取字段的标签名和选项 +func (r *Request) getFieldTag(field reflect.StructField, primaryTag string) (string, bool) { + // 优先使用指定的标签 + if tag := field.Tag.Get(primaryTag); tag != "" { + return r.parseTag(tag, field.Name) + } + + // 回退标签顺序 + fallbackTags := []string{"json", "form", "query"} + for _, tagName := range fallbackTags { + if tagName == primaryTag { + continue + } + if tag := field.Tag.Get(tagName); tag != "" { + return r.parseTag(tag, field.Name) + } + } + + // 使用字段名的小写形式 + return strings.ToLower(field.Name), false +} + +// parseTag 解析标签,返回名称和是否有omitempty选项 +func (r *Request) parseTag(tag, fieldName string) (string, bool) { + parts := strings.Split(tag, ",") + name := strings.TrimSpace(parts[0]) + + if name == "" { + name = strings.ToLower(fieldName) + } + + omitempty := false + for i := 1; i < len(parts); i++ { + if strings.TrimSpace(parts[i]) == "omitempty" { + omitempty = true + break + } + } + + return name, omitempty +} + +// isZeroValue 检查值是否为零值 +func (r *Request) isZeroValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.String: + return v.String() == "" + case reflect.Ptr, reflect.Interface, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func: + return v.IsNil() + case reflect.Array: + for i := 0; i < v.Len(); i++ { + if !r.isZeroValue(v.Index(i)) { + return false + } + } + return true + case reflect.Struct: + for i := 0; i < v.NumField(); i++ { + if v.Type().Field(i).IsExported() && !r.isZeroValue(v.Field(i)) { + return false + } + } + return true + default: + return false + } +} + +// convertToStringValues 将值转换为字符串数组 +func (r *Request) convertToStringValues(v reflect.Value) []string { + switch v.Kind() { + case reflect.String: + return []string{v.String()} + case reflect.Bool: + return []string{strconv.FormatBool(v.Bool())} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return []string{strconv.FormatInt(v.Int(), 10)} + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return []string{strconv.FormatUint(v.Uint(), 10)} + case reflect.Float32: + return []string{strconv.FormatFloat(v.Float(), 'f', -1, 32)} + case reflect.Float64: + return []string{strconv.FormatFloat(v.Float(), 'f', -1, 64)} + case reflect.Slice, reflect.Array: + var result []string + for i := 0; i < v.Len(); i++ { + values := r.convertToStringValues(v.Index(i)) + result = append(result, values...) + } + return result + case reflect.Ptr: + if v.IsNil() { + return []string{} + } + return r.convertToStringValues(v.Elem()) + case reflect.Interface: + if v.IsNil() { + return []string{} + } + return r.convertToStringValues(v.Elem()) + default: + // 对于其他类型,使用fmt.Sprintf + return []string{fmt.Sprintf("%v", v.Interface())} + } +} diff --git a/request_params_test.go b/request_params_test.go new file mode 100644 index 0000000..ac8fa99 --- /dev/null +++ b/request_params_test.go @@ -0,0 +1,264 @@ +package req + +import ( + "fmt" + "testing" +) + +// 测试结构体 +type QueryParams struct { + Name string `query:"name"` + Age int `query:"age,omitempty"` + Active bool `query:"active"` + Tags []string `query:"tags"` + Score float64 `query:"score"` + Ignored string `query:"-"` + NoTag string // 应该使用小写字段名 + Empty string `query:"empty,omitempty"` // 应该被忽略 +} + +type FormData struct { + Username string `form:"username"` + Password string `form:"password"` + Remember bool `form:"remember"` + Age int `form:"age,omitempty"` + Empty string `form:"empty,omitempty"` +} + +type JsonTags struct { + Name string `json:"name"` + Value int `json:"value,omitempty"` +} + +func TestSetQueryParamsMarshal(t *testing.T) { + client := C() + + // 测试基本功能 + params := QueryParams{ + Name: "john", + Age: 25, + Active: true, + Tags: []string{"tag1", "tag2"}, + Score: 98.5, + Ignored: "should_be_ignored", + NoTag: "notag_value", + Empty: "", // 应该被omitempty忽略 + } + + req := client.R().SetQueryParamsStruct(params) + + // 检查query参数 + queryParams := req.QueryParams + + // 验证基本字段 + if queryParams.Get("name") != "john" { + t.Errorf("Expected name=john, got %s", queryParams.Get("name")) + } + + if queryParams.Get("age") != "25" { + t.Errorf("Expected age=25, got %s", queryParams.Get("age")) + } + + if queryParams.Get("active") != "true" { + t.Errorf("Expected active=true, got %s", queryParams.Get("active")) + } + + if queryParams.Get("score") != "98.5" { + t.Errorf("Expected score=98.5, got %s", queryParams.Get("score")) + } + + // 验证数组字段 + tags := queryParams["tags"] + if len(tags) != 2 || tags[0] != "tag1" || tags[1] != "tag2" { + t.Errorf("Expected tags=[tag1, tag2], got %v", tags) + } + + // 验证忽略字段 + if queryParams.Get("ignored") != "" { + t.Errorf("Expected ignored field to be empty, got %s", queryParams.Get("ignored")) + } + + // 验证没有标签的字段 + if queryParams.Get("notag") != "notag_value" { + t.Errorf("Expected notag=notag_value, got %s", queryParams.Get("notag")) + } + + // 验证omitempty字段 + if queryParams.Get("empty") != "" { + t.Errorf("Expected empty field to be omitted, got %s", queryParams.Get("empty")) + } + + fmt.Printf("Query params: %v\n", queryParams) +} + +func TestSetFormDataMarshal(t *testing.T) { + client := C() + + formData := FormData{ + Username: "user123", + Password: "secret", + Remember: true, + Age: 0, // 应该被omitempty忽略 + Empty: "", // 应该被omitempty忽略 + } + + req := client.R().SetFormDataStruct(formData) + + // 检查form数据 + form := req.FormData + + if form.Get("username") != "user123" { + t.Errorf("Expected username=user123, got %s", form.Get("username")) + } + + if form.Get("password") != "secret" { + t.Errorf("Expected password=secret, got %s", form.Get("password")) + } + + if form.Get("remember") != "true" { + t.Errorf("Expected remember=true, got %s", form.Get("remember")) + } + + // 验证omitempty字段(Age为0且有omitempty标签应该被忽略) + if form.Get("age") != "" { + t.Errorf("Expected age to be omitted, got %s", form.Get("age")) + } + + if form.Get("empty") != "" { + t.Errorf("Expected empty to be omitted, got %s", form.Get("empty")) + } + + fmt.Printf("Form data: %v\n", form) +} + +func TestJsonTagsFallback(t *testing.T) { + client := C() + + params := JsonTags{ + Name: "test", + Value: 0, // 应该被omitempty忽略 + } + + req := client.R().SetQueryParamsStruct(params) + + // 应该回退到json标签 + if req.QueryParams.Get("name") != "test" { + t.Errorf("Expected name=test, got %s", req.QueryParams.Get("name")) + } + + // 验证omitempty + if req.QueryParams.Get("value") != "" { + t.Errorf("Expected value to be omitted, got %s", req.QueryParams.Get("value")) + } + + fmt.Printf("JSON tags query params: %v\n", req.QueryParams) +} + +func TestNilParams(t *testing.T) { + client := C() + + // 测试nil参数 + req := client.R().SetQueryParamsStruct(nil) + if req.QueryParams != nil && len(req.QueryParams) > 0 { + t.Errorf("Expected empty query params for nil input") + } + + req = client.R().SetFormDataStruct(nil) + if req.FormData != nil && len(req.FormData) > 0 { + t.Errorf("Expected empty form data for nil input") + } +} + +func TestPointerParams(t *testing.T) { + client := C() + + params := &QueryParams{ + Name: "pointer_test", + Active: true, + } + + req := client.R().SetQueryParamsStruct(params) + + if req.QueryParams.Get("name") != "pointer_test" { + t.Errorf("Expected name=pointer_test, got %s", req.QueryParams.Get("name")) + } +} + +func TestMergeWithExisting(t *testing.T) { + client := C() + + // 先设置一些手动的参数 + req := client.R().SetQueryParam("manual", "value") + + params := QueryParams{ + Name: "merge_test", + } + + // 再通过marshal添加参数 + req.SetQueryParamsStruct(params) + + // 验证两种参数都存在 + if req.QueryParams.Get("manual") != "value" { + t.Errorf("Expected manual=value, got %s", req.QueryParams.Get("manual")) + } + + if req.QueryParams.Get("name") != "merge_test" { + t.Errorf("Expected name=merge_test, got %s", req.QueryParams.Get("name")) + } + + fmt.Printf("Merged query params: %v\n", req.QueryParams) +} + +func TestZeroValueHandling(t *testing.T) { + client := C() + + // 测试零值处理 + type ZeroValueTest struct { + Name string `query:"name"` + Age int `query:"age"` // 零值应该被包含 + Score float64 `query:"score"` // 零值应该被包含 + Active bool `query:"active"` // false应该被包含 + EmptyWithTag string `query:"empty,omitempty"` // 零值应该被忽略 + ZeroWithTag int `query:"zero,omitempty"` // 零值应该被忽略 + } + + params := ZeroValueTest{ + Name: "", // 空字符串应该被包含 + Age: 0, // 零值应该被包含 + Score: 0.0, // 零值应该被包含 + Active: false, // false应该被包含 + EmptyWithTag: "", // 应该被忽略(omitempty) + ZeroWithTag: 0, // 应该被忽略(omitempty) + } + + req := client.R().SetQueryParamsStruct(params) + queryParams := req.QueryParams + + // 验证零值字段被包含 + if queryParams.Get("name") != "" { + t.Errorf("Expected empty name to be included, got %s", queryParams.Get("name")) + } + + if queryParams.Get("age") != "0" { + t.Errorf("Expected age=0, got %s", queryParams.Get("age")) + } + + if queryParams.Get("score") != "0" { + t.Errorf("Expected score=0, got %s", queryParams.Get("score")) + } + + if queryParams.Get("active") != "false" { + t.Errorf("Expected active=false, got %s", queryParams.Get("active")) + } + + // 验证omitempty字段被忽略 + if queryParams.Has("empty") { + t.Errorf("Expected empty field with omitempty to be omitted") + } + + if queryParams.Has("zero") { + t.Errorf("Expected zero field with omitempty to be omitted") + } + + fmt.Printf("Zero value test query params: %v\n", queryParams) +}