enhance: support upsert autoid==true
Signed-off-by: lixinguo <[email protected]>
lixinguo committed Jan 29, 2024
1 parent f637526 commit cd6e32f
Showing 6 changed files with 142 additions and 23 deletions.
3 changes: 1 addition & 2 deletions pymilvus/client/grpc_handler.py
@@ -612,7 +612,6 @@ def _prepare_batch_upsert_request(
         entities: List,
         partition_name: Optional[str] = None,
         timeout: Optional[float] = None,
-        is_insert: bool = True,
         **kwargs,
     ):
         param = kwargs.get("upsert_param")
@@ -647,7 +646,7 @@ def upsert(
 
         try:
             request = self._prepare_batch_upsert_request(
-                collection_name, entities, partition_name, timeout, False, **kwargs
+                collection_name, entities, partition_name, timeout, **kwargs
             )
             rf = self._stub.Upsert.future(request, timeout=timeout)
             if kwargs.get("_async", False) is True:
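Note: with the internal `is_insert` flag gone, `_prepare_batch_upsert_request` always builds an `UpsertRequest`, so `upsert` no longer threads a boolean through. Combined with the changes below, this is what lets upsert accept collections declared with `auto_id=True`. A minimal usage sketch of the behavior this commit enables (not part of the diff; collection and field names are invented, and a reachable Milvus server is assumed):

```python
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections

connections.connect(host="localhost", port="19530")

schema = CollectionSchema(
    fields=[
        FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema("vec", DataType.FLOAT_VECTOR, dim=4),
    ]
)
coll = Collection("upsert_demo", schema)

# Before this commit the call below raised UpsertAutoIDTrueException;
# now the pk column simply has to be supplied explicitly, auto_id or not.
coll.upsert([[1, 2, 3], [[0.1] * 4, [0.2] * 4, [0.3] * 4]])
```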
39 changes: 34 additions & 5 deletions pymilvus/client/prepare.py
@@ -19,7 +19,7 @@
     REDUCE_STOP_FOR_BEST,
 )
 from .types import DataType, PlaceholderType, get_consistency_level
-from .utils import traverse_info, traverse_rows_info
+from .utils import traverse_insert_info, traverse_rows_info, traverse_upsert_info
 
 
 class Prepare:
@@ -447,7 +447,7 @@ def row_upsert_param(
         return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
 
     @staticmethod
-    def _pre_batch_check(
+    def _pre_insert_batch_check(
         entities: List,
         fields_info: Any,
     ):
@@ -463,7 +463,7 @@ def _pre_batch_check(
         if not fields_info:
             raise ParamError(message="Missing collection meta to validate entities")
 
-        location, primary_key_loc, auto_id_loc = traverse_info(fields_info, entities)
+        location, primary_key_loc, auto_id_loc = traverse_insert_info(fields_info, entities)
 
         # though impossible from sdk
         if primary_key_loc is None:
@@ -478,6 +478,35 @@ def _pre_batch_check(
             raise ParamError(msg)
         return location
 
+    @staticmethod
+    def _pre_upsert_batch_check(
+        entities: List,
+        fields_info: Any,
+    ):
+        for entity in entities:
+            if (
+                not entity.get("name", None)
+                or not entity.get("values", None)
+                or not entity.get("type", None)
+            ):
+                raise ParamError(
+                    message="Missing param in entities, a field must have type, name and values"
+                )
+        if not fields_info:
+            raise ParamError(message="Missing collection meta to validate entities")
+
+        location, primary_key_loc = traverse_upsert_info(fields_info, entities)
+
+        # though impossible from sdk
+        if primary_key_loc is None:
+            raise ParamError(message="primary key not found")
+
+        if len(entities) != len(fields_info):
+            msg = f"number of fields: {len(fields_info)}, number of entities: {len(entities)}"
+            raise ParamError(msg)
+
+        return location
+
     @staticmethod
     def _parse_batch_request(
         request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest],
@@ -518,7 +547,7 @@ def batch_insert_param(
         partition_name: str,
         fields_info: Any,
     ):
-        location = cls._pre_batch_check(entities, fields_info)
+        location = cls._pre_insert_batch_check(entities, fields_info)
         tag = partition_name if isinstance(partition_name, str) else ""
         request = milvus_types.InsertRequest(collection_name=collection_name, partition_name=tag)
 
@@ -532,7 +561,7 @@ def batch_upsert_param(
         partition_name: str,
         fields_info: Any,
     ):
-        location = cls._pre_batch_check(entities, fields_info)
+        location = cls._pre_upsert_batch_check(entities, fields_info)
         tag = partition_name if isinstance(partition_name, str) else ""
         request = milvus_types.UpsertRequest(collection_name=collection_name, partition_name=tag)
 
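Note: `_pre_insert_batch_check` and `_pre_upsert_batch_check` validate the same column-oriented payload; the upsert variant just drops the `auto_id_loc` bookkeeping, because the caller must always supply the primary-key column. A sketch of the structure both checks expect (field names and dims invented for illustration):

```python
from pymilvus import DataType

fields_info = [  # from the collection schema, one dict per field
    {"name": "id", "type": DataType.INT64, "is_primary": True, "auto_id": True},
    {"name": "vec", "type": DataType.FLOAT_VECTOR, "params": {"dim": 4}},
]
entities = [  # one dict per field: name, type, and a column of values
    {"name": "id", "type": DataType.INT64, "values": [1, 2]},
    {"name": "vec", "type": DataType.FLOAT_VECTOR, "values": [[0.0] * 4, [1.0] * 4]},
]
# For upsert, len(entities) must equal len(fields_info): the pk column
# is required even though the schema says auto_id=True.
```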
49 changes: 48 additions & 1 deletion pymilvus/client/utils.py
@@ -243,7 +243,7 @@ def traverse_rows_info(fields_info: Any, entities: List):
     return location, primary_key_loc, auto_id_loc
 
 
-def traverse_info(fields_info: Any, entities: List):
+def traverse_insert_info(fields_info: Any, entities: List):
     location, primary_key_loc, auto_id_loc = {}, None, None
     for i, field in enumerate(fields_info):
         if field.get("is_primary", False):
@@ -294,5 +294,52 @@ def traverse_info(fields_info: Any, entities: List):
     return location, primary_key_loc, auto_id_loc
 
 
+def traverse_upsert_info(fields_info: Any, entities: List):
+    location, primary_key_loc = {}, None
+    for i, field in enumerate(fields_info):
+        if field.get("is_primary", False):
+            primary_key_loc = i
+
+        match_flag = False
+        field_name = field["name"]
+        field_type = field["type"]
+
+        for entity in entities:
+            entity_name, entity_type = entity["name"], entity["type"]
+
+            if field_name == entity_name:
+                if field_type != entity_type:
+                    raise ParamError(
+                        message=f"Collection field type is {field_type}"
+                        f", but entities field type is {entity_type}"
+                    )
+
+                entity_dim, field_dim = 0, 0
+                if entity_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
+                    field_dim = field["params"]["dim"]
+                    entity_dim = len(entity["values"][0])
+
+                if entity_type in [DataType.FLOAT_VECTOR] and entity_dim != field_dim:
+                    raise ParamError(
+                        message=f"Collection field dim is {field_dim}"
+                        f", but entities field dim is {entity_dim}"
+                    )
+
+                if entity_type in [DataType.BINARY_VECTOR] and entity_dim * 8 != field_dim:
+                    raise ParamError(
+                        message=f"Collection field dim is {field_dim}"
+                        f", but entities field dim is {entity_dim * 8}"
+                    )
+
+                location[field["name"]] = i
+                match_flag = True
+                break
+
+        if not match_flag:
+            raise ParamError(message=f"Field {field['name']} don't match in entities")
+
+    return location, primary_key_loc
+
+
 def get_server_type(host: str):
     return ZILLIZ if (isinstance(host, str) and "zilliz" in host.lower()) else MILVUS
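Note: `traverse_upsert_info` is `traverse_insert_info` minus the auto-id handling: it maps each schema field to its position in `entities`, enforces name/type agreement, and validates vector dimensions. The `* 8` factor exists because a binary vector packs eight dimensions per byte; a small sketch with invented values:

```python
# For BINARY_VECTOR fields the schema dim counts bits, while each entity
# value is a byte sequence, hence the check dim == len(value) * 8.
binary_dim = 16                      # schema: params={"dim": 16}
value = b"\x00\x01"                  # one 2-byte binary vector
assert len(value) * 8 == binary_dim  # passes traverse_upsert_info's check
```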
6 changes: 1 addition & 5 deletions pymilvus/exceptions.py
@@ -137,10 +137,6 @@ class InvalidConsistencyLevel(MilvusException):
     """Raise when consistency level is invalid"""
 
 
-class UpsertAutoIDTrueException(MilvusException):
-    """Raise when upsert autoID is true"""
-
-
 class ExceptionsMessage:
     NoHostPort = "connection configuration must contain 'host' and 'port'."
     HostType = "Type of 'host' must be str."
@@ -211,7 +207,7 @@ class ExceptionsMessage:
     InsertUnexpectedField = (
         "Attempt to insert an unexpected field to collection without enabling dynamic field"
     )
-    UpsertAutoIDTrue = "Upsert don't support autoid == true"
+    UpsertPrimaryKeyEmpty = "Upsert need to assign pk"
     AmbiguousDeleteFilterParam = (
         "Ambiguous filter parameter, only one deletion condition can be specified."
     )
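Note: with `UpsertAutoIDTrueException` removed, the new `UpsertPrimaryKeyEmpty` message surfaces as a `DataNotMatchException` (raised in `orm/prepare.py` below) when a DataFrame's pk column is entirely null. A hedged sketch of the replacement error path, reusing `coll` from the earlier sketch with invented DataFrame contents:

```python
import pandas as pd
from pymilvus.exceptions import DataNotMatchException

df = pd.DataFrame({"id": [None, None], "vec": [[0.1] * 4, [0.2] * 4]})
try:
    coll.upsert(df)
except DataNotMatchException as e:
    print(e)  # roughly: "Upsert need to assign pk"
```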
34 changes: 30 additions & 4 deletions pymilvus/orm/prepare.py
@@ -20,7 +20,6 @@
     DataNotMatchException,
     DataTypeNotSupportException,
     ExceptionsMessage,
-    UpsertAutoIDTrueException,
 )
 
 from .schema import CollectionSchema
@@ -82,7 +81,34 @@ def prepare_upsert_data(
         data: Union[List, Tuple, pd.DataFrame],
         schema: CollectionSchema,
     ) -> List:
-        if schema.auto_id:
-            raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue)
+        if not isinstance(data, (list, tuple, pd.DataFrame)):
+            raise DataTypeNotSupportException(message=ExceptionsMessage.DataTypeNotSupport)
 
-        return cls.prepare_insert_data(data, schema)
+        fields = schema.fields
+        entities = []  # Entities
+
+        if isinstance(data, pd.DataFrame):
+            if schema.primary_field.name in data and data[schema.primary_field.name].isnull().all():
+                raise DataNotMatchException(message=ExceptionsMessage.UpsertPrimaryKeyEmpty)
+            for field in fields:
+                values = []
+                if field.name in list(data.columns):
+                    values = list(data[field.name])
+                entities.append({"name": field.name, "type": field.dtype, "values": values})
+            return entities
+
+        tmp_fields = copy.deepcopy(fields)
+
+        for i, field in enumerate(tmp_fields):
+            try:
+                if isinstance(data[i], np.ndarray):
+                    d = data[i].tolist()
+                else:
+                    d = data[i] if data[i] is not None else []
+
+                entities.append({"name": field.name, "type": field.dtype, "values": d})
+            # the last missing part of data is also completed in order according to the schema
+            except IndexError:
+                entities.append({"name": field.name, "type": field.dtype, "values": []})
+
+        return entities
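Note: `prepare_upsert_data` no longer delegates to `prepare_insert_data`; it builds the column-oriented entities itself, keeps the pk column for both DataFrame and list input, and pads missing trailing columns with empty lists. Roughly, for the schema from the first sketch (values invented):

```python
data = [[1, 2, 3], [[0.1] * 4, [0.2] * 4, [0.3] * 4]]
# prepare_upsert_data(data, schema) ->
# [
#     {"name": "id",  "type": DataType.INT64,        "values": [1, 2, 3]},
#     {"name": "vec", "type": DataType.FLOAT_VECTOR, "values": [[...], ...]},
# ]
# If data omitted a trailing column, the IndexError branch would emit
# {"name": ..., "values": []} so downstream checks see every schema field.
```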
34 changes: 28 additions & 6 deletions pymilvus/orm/schema.py
@@ -27,7 +27,6 @@
     PartitionKeyException,
     PrimaryKeyException,
     SchemaNotReadyException,
-    UpsertAutoIDTrueException,
 )
 
 from .constants import COMMON_TYPE_PARAMS
@@ -417,7 +416,7 @@ def _check_insert_data(data: Union[List[List], pd.DataFrame]):
             raise DataTypeNotSupportException(message="data should be a list of list")
 
 
-def _check_data_schema_cnt(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
+def _check_insert_data_schema_cnt(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
     tmp_fields = copy.deepcopy(schema.fields)
     for i, field in enumerate(tmp_fields):
         if field.is_primary and field.auto_id:
@@ -456,17 +455,40 @@ def check_insert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
             columns.remove(schema.primary_field)
             data = data[[columns]]
 
-    _check_data_schema_cnt(schema, data)
+    _check_insert_data_schema_cnt(schema, data)
     _check_insert_data(data)
 
 
+def _check_upsert_data_schema_cnt(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
+    tmp_fields = copy.deepcopy(schema.fields)
+    # upsert must assign pk even when autoid==true
+    field_cnt = len(tmp_fields)
+    is_dataframe = isinstance(data, pd.DataFrame)
+    data_cnt = len(data.columns) if is_dataframe else len(data)
+    if field_cnt != data_cnt:
+        message = (
+            f"The data don't match with schema fields, expect {field_cnt} list, got {len(data)}"
+        )
+        if is_dataframe:
+            i_name = [f.name for f in tmp_fields]
+            t_name = list(data.columns)
+            message = f"The fields don't match with schema fields, expected: {i_name}, got {t_name}"
+
+        raise DataNotMatchException(message=message)
+
+    if is_dataframe:
+        for x, y in zip(list(data.columns), tmp_fields):
+            if x != y.name:
+                raise DataNotMatchException(
+                    message=f"The name of field don't match, expected: {y.name}, got {x}"
+                )
+
+
 def check_upsert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
     if schema is None:
         raise SchemaNotReadyException(message="Schema shouldn't be None")
-    if schema.auto_id:
-        raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue)
 
-    _check_data_schema_cnt(schema, data)
+    _check_upsert_data_schema_cnt(schema, data)
     _check_insert_data(data)
 
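Note: the insert and upsert count checks now diverge: `_check_insert_data_schema_cnt` still skips an `auto_id` primary key (clients must not send it on insert), while `_check_upsert_data_schema_cnt` counts every field, so an upsert payload for an N-field schema always needs N columns. With the two-field schema from the first sketch (values invented):

```python
insert_data = [[[0.1] * 4, [0.2] * 4]]          # 1 column: auto_id pk omitted
upsert_data = [[1, 2], [[0.1] * 4, [0.2] * 4]]  # 2 columns: pk required
# check_insert_schema accepts insert_data; check_upsert_schema rejects it
# ("The data don't match with schema fields, expect 2 list, got 1")
# and accepts upsert_data.
```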
