Skip to content

Commit

Permalink
fix: [2.3] Validate num of rows for insert field data with schema (#3…
Browse files Browse the repository at this point in the history
…2770) (#32845)

Cherry-pick from master
pr: #32770 
See also #32769

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia committed May 8, 2024
1 parent f72e89b commit a631856
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 1 deletion.
2 changes: 1 addition & 1 deletion internal/proxy/validate_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil

default:
// error won't happen here.
n, err := funcutil.GetNumRowOfFieldData(field)
n, err := funcutil.GetNumRowOfFieldDataWithSchema(field, schema)
if err != nil {
return err
}
Expand Down
45 changes: 45 additions & 0 deletions pkg/util/funcutil/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,51 @@ func GetNumRowsOfBinaryVectorField(bDatas []byte, dim int64) (uint64, error) {
return uint64((8 * int64(l)) / dim), nil
}

// GetNumRowOfFieldDataWithSchema returns num of rows with schema specification.
func GetNumRowOfFieldDataWithSchema(fieldData *schemapb.FieldData, helper *typeutil.SchemaHelper) (uint64, error) {
var fieldNumRows uint64
var err error
fieldSchema, err := helper.GetFieldFromName(fieldData.GetFieldName())
if err != nil {
return 0, err
}
switch fieldSchema.GetDataType() {
case schemapb.DataType_Bool:
fieldNumRows = getNumRowsOfScalarField(fieldData.GetScalars().GetBoolData().GetData())
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
fieldNumRows = getNumRowsOfScalarField(fieldData.GetScalars().GetIntData().GetData())
case schemapb.DataType_Int64:
fieldNumRows = getNumRowsOfScalarField(fieldData.GetScalars().GetLongData().GetData())
case schemapb.DataType_Float:
fieldNumRows = getNumRowsOfScalarField(fieldData.GetScalars().GetFloatData().GetData())
case schemapb.DataType_Double:
fieldNumRows = getNumRowsOfScalarField(fieldData.GetScalars().GetDoubleData().GetData())
case schemapb.DataType_String, schemapb.DataType_VarChar:
fieldNumRows = getNumRowsOfScalarField(fieldData.GetScalars().GetStringData().GetData())
case schemapb.DataType_Array:
fieldNumRows = getNumRowsOfScalarField(fieldData.GetScalars().GetArrayData().GetData())
case schemapb.DataType_JSON:
fieldNumRows = getNumRowsOfScalarField(fieldData.GetScalars().GetJsonData().GetData())
case schemapb.DataType_FloatVector:
dim := fieldData.GetVectors().GetDim()
fieldNumRows, err = GetNumRowsOfFloatVectorField(fieldData.GetVectors().GetFloatVector().GetData(), dim)
if err != nil {
return 0, err
}
case schemapb.DataType_BinaryVector:
dim := fieldData.GetVectors().GetDim()
fieldNumRows, err = GetNumRowsOfBinaryVectorField(fieldData.GetVectors().GetBinaryVector(), dim)
if err != nil {
return 0, err
}
default:
return 0, fmt.Errorf("%s is not supported now", fieldSchema.GetDataType())
}

return fieldNumRows, nil
}

// GetNumRowOfFieldData returns num of rows from the field data type
func GetNumRowOfFieldData(fieldData *schemapb.FieldData) (uint64, error) {
var fieldNumRows uint64
var err error
Expand Down
234 changes: 234 additions & 0 deletions pkg/util/funcutil/func_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@ import (

"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
grpcCodes "google.golang.org/grpc/codes"
grpcStatus "google.golang.org/grpc/status"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

func Test_CheckGrpcReady(t *testing.T) {
Expand Down Expand Up @@ -440,3 +443,234 @@ func TestMapToJSON(t *testing.T) {
assert.NoError(t, err)
assert.True(t, reflect.DeepEqual(m, got))
}

type NumRowsWithSchemaSuite struct {
suite.Suite
helper *typeutil.SchemaHelper
}

func (s *NumRowsWithSchemaSuite) SetupSuite() {
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "int8", DataType: schemapb.DataType_Int8},
{FieldID: 102, Name: "int16", DataType: schemapb.DataType_Int16},
{FieldID: 103, Name: "int32", DataType: schemapb.DataType_Int32},
{FieldID: 104, Name: "bool", DataType: schemapb.DataType_Bool},
{FieldID: 105, Name: "float", DataType: schemapb.DataType_Float},
{FieldID: 106, Name: "double", DataType: schemapb.DataType_Double},
{FieldID: 107, Name: "varchar", DataType: schemapb.DataType_VarChar},
{FieldID: 108, Name: "array", DataType: schemapb.DataType_Array},
{FieldID: 109, Name: "json", DataType: schemapb.DataType_JSON},
{FieldID: 110, Name: "float_vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "8"}}},
{FieldID: 111, Name: "binary_vector", DataType: schemapb.DataType_BinaryVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "8"}}},
{FieldID: 999, Name: "unknown", DataType: schemapb.DataType_None},
},
}
helper, err := typeutil.CreateSchemaHelper(schema)
s.Require().NoError(err)
s.helper = helper
}

func (s *NumRowsWithSchemaSuite) TestNormalCases() {
type testCase struct {
tag string
input *schemapb.FieldData
expect uint64
}

cases := []*testCase{
{
tag: "int64",
input: &schemapb.FieldData{
FieldName: "int64",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}},
},
},
expect: 3,
},
{
tag: "int8",
input: &schemapb.FieldData{
FieldName: "int8",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{1, 2, 3, 4}}}},
},
},
expect: 4,
},
{
tag: "int16",
input: &schemapb.FieldData{
FieldName: "int16",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{1, 2, 3, 4, 5}}}},
},
},
expect: 5,
},
{
tag: "int32",
input: &schemapb.FieldData{
FieldName: "int32",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{1, 2, 3, 4, 5}}}},
},
},
expect: 5,
},
{
tag: "bool",
input: &schemapb.FieldData{
FieldName: "bool",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_BoolData{BoolData: &schemapb.BoolArray{Data: make([]bool, 4)}}},
},
},
expect: 4,
},
{
tag: "float",
input: &schemapb.FieldData{
FieldName: "float",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_FloatData{FloatData: &schemapb.FloatArray{Data: make([]float32, 6)}}},
},
},
expect: 6,
},
{
tag: "double",
input: &schemapb.FieldData{
FieldName: "double",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_DoubleData{DoubleData: &schemapb.DoubleArray{Data: make([]float64, 8)}}},
},
},
expect: 8,
},
{
tag: "varchar",
input: &schemapb.FieldData{
FieldName: "varchar",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: make([]string, 7)}}},
},
},
expect: 7,
},
{
tag: "array",
input: &schemapb.FieldData{
FieldName: "array",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_ArrayData{ArrayData: &schemapb.ArrayArray{Data: make([]*schemapb.ScalarField, 9)}}},
},
},
expect: 9,
},
{
tag: "json",
input: &schemapb.FieldData{
FieldName: "json",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_JsonData{JsonData: &schemapb.JSONArray{Data: make([][]byte, 7)}}},
},
},
expect: 7,
},
{
tag: "float_vector",
input: &schemapb.FieldData{
FieldName: "float_vector",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 8,
Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: make([]float32, 7*8)}},
},
},
},
expect: 7,
},
{
tag: "binary_vector",
input: &schemapb.FieldData{
FieldName: "binary_vector",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 8,
Data: &schemapb.VectorField_BinaryVector{BinaryVector: make([]byte, 8)},
},
},
},
expect: 8,
},
}
for _, tc := range cases {
s.Run(tc.tag, func() {
r, err := GetNumRowOfFieldDataWithSchema(tc.input, s.helper)
s.NoError(err)
s.Equal(tc.expect, r)
})
}
}

func (s *NumRowsWithSchemaSuite) TestErrorCases() {
s.Run("nil_field_data", func() {
_, err := GetNumRowOfFieldDataWithSchema(nil, s.helper)
s.Error(err)
})

s.Run("data_type_unknown", func() {
_, err := GetNumRowOfFieldDataWithSchema(&schemapb.FieldData{
FieldName: "unknown",
}, s.helper)
s.Error(err)
})

s.Run("bad_dim_vector", func() {
type testCase struct {
tag string
input *schemapb.FieldData
}

cases := []testCase{
{
tag: "float_vector",
input: &schemapb.FieldData{
FieldName: "float_vector",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 3,
Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: make([]float32, 7*8)}},
},
},
},
},
{
tag: "binary_vector",
input: &schemapb.FieldData{
FieldName: "binary_vector",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 5,
Data: &schemapb.VectorField_BinaryVector{BinaryVector: make([]byte, 8)},
},
},
},
},
}

for _, tc := range cases {
s.Run(tc.tag, func() {
_, err := GetNumRowOfFieldDataWithSchema(tc.input, s.helper)
s.Error(err)
})
}
})
}

func TestNumRowsWithSchema(t *testing.T) {
suite.Run(t, new(NumRowsWithSchemaSuite))
}

0 comments on commit a631856

Please sign in to comment.