Skip to content

Commit

Permalink
Fix CalcDistance wrong result when fetching vectors from collection (#…
Browse files Browse the repository at this point in the history
…6976)

* Fix CalcDistance wrong result when fetching vectors from collection

Signed-off-by: yhmo <[email protected]>

* Fix CalcDistance wrong result when fetching vectors from collection

Signed-off-by: yhmo <[email protected]>

* preset capacity

Signed-off-by: yhmo <[email protected]>

* typo

Signed-off-by: yhmo <[email protected]>

* error check

Signed-off-by: yhmo <[email protected]>

* code lint

Signed-off-by: yhmo <[email protected]>
  • Loading branch information
yhmo committed Aug 10, 2021
1 parent 3c3975b commit bdb8396
Showing 1 changed file with 110 additions and 13 deletions.
123 changes: 110 additions & 13 deletions internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,10 @@ func (node *Proxy) Retrieve(ctx context.Context, request *milvuspb.RetrieveReque
zap.Any("partitions", request.PartitionNames),
zap.Any("len(Ids)", len(request.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)))
defer func() {
idsCount := 0
if rt.result != nil {
idsCount = len(rt.result.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)
}
log.Debug("Retrieve Done",
zap.Error(err),
zap.String("role", Params.RoleName),
Expand All @@ -1368,7 +1372,7 @@ func (node *Proxy) Retrieve(ctx context.Context, request *milvuspb.RetrieveReque
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.Any("partitions", request.PartitionNames),
zap.Any("len(Ids)", len(rt.result.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)))
zap.Any("len(Ids)", idsCount))
}()

err = rt.WaitToFinish()
Expand Down Expand Up @@ -1593,6 +1597,80 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
return node.Retrieve(ctx, retrieveRequest)
}

// arrangeFunc re-orders retrieved vectors so that they follow the order of the
// input ids; the retrieved rows are not guaranteed to come back in that order.
//
// ids carries the vector field name and the requested integer ids.
// retrievedFields is the field data returned by the retrieve call; it must
// contain the vector field named ids.FieldName and an Int64 column of ids.
//
// It returns a VectorField holding one vector per input id (duplicates in the
// input are honoured), or an error when the id/vector columns are missing, an
// input id was not retrieved, or the retrieved data is shorter than the
// dimension implies.
arrangeFunc := func(ids *milvuspb.VectorIDs, retrievedFields []*schemapb.FieldData) (*schemapb.VectorField, error) {
	var retrievedIds *schemapb.ScalarField
	var retrievedVectors *schemapb.VectorField
	for _, fieldData := range retrievedFields {
		if fieldData.FieldName == ids.FieldName {
			retrievedVectors = fieldData.GetVectors()
		}
		// NOTE(review): any Int64 field is taken as the id column (the last one
		// wins) — presumably only the primary key is retrieved; confirm.
		if fieldData.Type == schemapb.DataType_Int64 {
			retrievedIds = fieldData.GetScalars()
		}
	}

	if retrievedIds == nil || retrievedVectors == nil {
		return nil, errors.New("Failed to fetch vectors")
	}

	dim := retrievedVectors.GetDim()
	if dim <= 0 {
		// A non-positive dimension would make the capacity/slice math below invalid.
		return nil, errors.New("Failed to fetch vectors")
	}

	// Map each retrieved id to its row index. The generated getters are nil-safe,
	// so a scalar field that does not hold long data yields an empty map instead
	// of a nil-pointer panic.
	retrievedIDData := retrievedIds.GetLongData().GetData()
	dict := make(map[int64]int, len(retrievedIDData))
	for index, id := range retrievedIDData {
		dict[id] = index
	}

	// lookup resolves an input id to the [begin, end) range of its vector inside
	// a flat data array of dataLen elements, width elements per vector.
	lookup := func(id int64, width int64, dataLen int) (int64, int64, error) {
		index, ok := dict[id]
		if !ok {
			log.Error("id not found in CalcDistance", zap.Int64("id", id))
			return 0, 0, fmt.Errorf("Failed to fetch vectors by id: %d", id)
		}
		begin := int64(index) * width
		end := begin + width
		if end > int64(dataLen) {
			log.Error("retrieved vector data is too short in CalcDistance", zap.Int64("id", id))
			return 0, 0, fmt.Errorf("Failed to fetch vectors by id: %d", id)
		}
		return begin, end, nil
	}

	inputIds := ids.IdArray.GetIntId().GetData()

	if floatVector := retrievedVectors.GetFloatVector(); floatVector != nil {
		floatArr := floatVector.GetData()
		result := make([]float32, 0, int64(len(inputIds))*dim)
		for _, id := range inputIds {
			begin, end, err := lookup(id, dim, len(floatArr))
			if err != nil {
				return nil, err
			}
			result = append(result, floatArr[begin:end]...)
		}

		return &schemapb.VectorField{
			Dim: dim,
			Data: &schemapb.VectorField_FloatVector{
				FloatVector: &schemapb.FloatArray{
					Data: result,
				},
			},
		}, nil
	}

	if binaryArr := retrievedVectors.GetBinaryVector(); binaryArr != nil {
		// A binary vector packs 8 dimensions per byte, so each vector occupies
		// ceil(dim/8) bytes — not dim bytes.
		bytesPerVector := (dim + 7) / 8
		result := make([]byte, 0, int64(len(inputIds))*bytesPerVector)
		for _, id := range inputIds {
			begin, end, err := lookup(id, bytesPerVector, len(binaryArr))
			if err != nil {
				return nil, err
			}
			result = append(result, binaryArr[begin:end]...)
		}

		return &schemapb.VectorField{
			// Report the dimension in bits, exactly as it was retrieved.
			Dim: dim,
			Data: &schemapb.VectorField_BinaryVector{
				BinaryVector: result,
			},
		}, nil
	}

	return nil, errors.New("Failed to fetch vectors")
}

vectorsLeft := request.GetOpLeft().GetDataArray()
opLeft := request.GetOpLeft().GetIdArray()
if opLeft != nil {
Expand All @@ -1606,11 +1684,14 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
}, nil
}

for _, fieldData := range result.FieldsData {
if fieldData.FieldName == opLeft.FieldName {
vectorsLeft = fieldData.GetVectors()
break
}
vectorsLeft, err = arrangeFunc(opLeft, result.FieldsData)
if err != nil {
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
}

Expand All @@ -1636,11 +1717,14 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
}, nil
}

for _, fieldData := range result.FieldsData {
if fieldData.FieldName == opRight.FieldName {
vectorsRight = fieldData.GetVectors()
break
}
vectorsRight, err = arrangeFunc(opRight, result.FieldsData)
if err != nil {
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
}

Expand All @@ -1653,7 +1737,16 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
}, nil
}

if vectorsLeft.Dim == vectorsRight.Dim && vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil {
if vectorsLeft.Dim != vectorsRight.Dim {
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "Vectors dimension is not equal",
},
}, nil
}

if vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil {
distances, err := distance.CalcFloatDistance(vectorsLeft.Dim, vectorsLeft.GetFloatVector().Data, vectorsRight.GetFloatVector().Data, metric)
if err != nil {
return &milvuspb.CalcDistanceResults{
Expand All @@ -1674,7 +1767,7 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
}, nil
}

if vectorsLeft.Dim == vectorsRight.Dim && vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil {
if vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil {
hamming, err := distance.CalcHammingDistance(vectorsLeft.Dim, vectorsLeft.GetBinaryVector(), vectorsRight.GetBinaryVector())
if err != nil {
return &milvuspb.CalcDistanceResults{
Expand Down Expand Up @@ -1719,6 +1812,10 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
}

err = errors.New("Unexpected error")
if (vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetFloatVector() != nil) || (vectorsLeft.GetFloatVector() != nil && vectorsRight.GetBinaryVector() != nil) {
err = errors.New("Cannot calculate distance between binary vectors and float vectors")
}

return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Expand Down

0 comments on commit bdb8396

Please sign in to comment.