1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
|
package openapi3filter_test
import (
"bytes"
"context"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/openapi3filter"
"github.com/getkin/kin-openapi/routers/gorillamux"
)
// TestValidateCsvFileUpload posts a multipart/form-data request containing a
// single "file" part with Content-Type text/csv to the /test operation and
// checks the result of openapi3filter.ValidateRequestBody: well-formed CSV
// payloads are expected to pass, malformed ones are expected to be rejected.
func TestValidateCsvFileUpload(t *testing.T) {
	// The spec is YAML, so the nesting below is significant and must be
	// expressed with spaces (tabs are not valid YAML indentation).
	const spec = `
openapi: 3.0.0
info:
  title: 'Validator'
  version: 0.0.1
paths:
  /test:
    post:
      requestBody:
        required: true
        content:
          multipart/form-data:
            schema:
              type: object
              required:
                - file
              properties:
                file:
                  type: string
                  format: string
      responses:
        '200':
          description: Created
`

	loader := openapi3.NewLoader()
	doc, err := loader.LoadFromData([]byte(spec))
	require.NoError(t, err)

	err = doc.Validate(loader.Context)
	require.NoError(t, err)

	router, err := gorillamux.NewRouter(doc)
	require.NoError(t, err)

	tests := []struct {
		name    string
		csvData string
		wantErr bool
	}{
		{
			name:    "single row",
			csvData: `foo,bar`,
			wantErr: false,
		},
		{
			name:    "quoted fields",
			csvData: `"foo","bar"`,
			wantErr: false,
		},
		{
			name: "two rows",
			csvData: `foo,bar
baz,qux`,
			wantErr: false,
		},
		{
			name: "inconsistent field count",
			csvData: `foo,bar
baz,qux,quux`,
			wantErr: true,
		},
		{
			name:    "unbalanced quotes",
			csvData: `"""`,
			wantErr: true,
		},
	}

	for _, tt := range tests {
		tt := tt // keep a per-iteration copy for toolchains older than Go 1.22
		t.Run(tt.name, func(t *testing.T) {
			var body bytes.Buffer
			writer := multipart.NewWriter(&body)

			// Add the CSV payload as the "file" form part.
			h := make(textproto.MIMEHeader)
			h.Set("Content-Disposition", `form-data; name="file"; filename="hello.csv"`)
			h.Set("Content-Type", "text/csv")

			fw, err := writer.CreatePart(h)
			require.NoError(t, err)
			_, err = io.Copy(fw, strings.NewReader(tt.csvData))
			require.NoError(t, err)

			// Close writes the terminating boundary; if it fails the body is
			// malformed and any validation result would be meaningless.
			require.NoError(t, writer.Close())

			req, err := http.NewRequest(http.MethodPost, "/test", bytes.NewReader(body.Bytes()))
			require.NoError(t, err)
			req.Header.Set("Content-Type", writer.FormDataContentType())

			route, pathParams, err := router.FindRoute(req)
			require.NoError(t, err)

			err = openapi3filter.ValidateRequestBody(
				context.Background(),
				&openapi3filter.RequestValidationInput{
					Request:    req,
					PathParams: pathParams,
					Route:      route,
				},
				route.Operation.RequestBody.Value,
			)

			switch {
			case err != nil && !tt.wantErr:
				t.Errorf("got %v", err)
			case err == nil && tt.wantErr:
				t.Errorf("want err")
			}
		})
	}
}
|