Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GSoC2024]Import annotations keeping current ones(#4747) #7771

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Added

- Added a feature that allows to keep current annotations without deleting them when new ones are imported, by checking the option.
(<https://github.com/cvat-ai/cvat/pull/7771>)
12 changes: 10 additions & 2 deletions cvat-ui/src/actions/import-actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ export const importDatasetAsync = (
sourceStorage: Storage,
file: File | string,
convMaskToPoly: boolean,
keepOldAnnotations: boolean,
): ThunkAction => (
async (dispatch, getState) => {
const resource = instance instanceof core.classes.Project ? 'dataset' : 'annotation';
Expand All @@ -86,6 +87,7 @@ export const importDatasetAsync = (
await instance.annotations
.importDataset(format, useDefaultSettings, sourceStorage, file, {
convMaskToPoly,
keepOldAnnotations,
updateStatusCallback: (message: string, progress: number) => (
dispatch(importActions.importDatasetUpdateStatus(
instance, Math.floor(progress * 100), message,
Expand All @@ -97,7 +99,10 @@ export const importDatasetAsync = (
throw Error('Only one importing of annotation/dataset allowed at the same time');
}
dispatch(importActions.importDataset(instance, format));
await instance.annotations.upload(format, useDefaultSettings, sourceStorage, file, { convMaskToPoly });
await instance.annotations.upload(format, useDefaultSettings, sourceStorage, file, {
convMaskToPoly,
keepOldAnnotations,
});
} else { // job
if (state.import.tasks.dataset.current?.[instance.taskId]) {
throw Error('Annotations is being uploaded for the task');
Expand All @@ -108,7 +113,10 @@ export const importDatasetAsync = (

dispatch(importActions.importDataset(instance, format));

await instance.annotations.upload(format, useDefaultSettings, sourceStorage, file, { convMaskToPoly });
await instance.annotations.upload(format, useDefaultSettings, sourceStorage, file, {
convMaskToPoly,
keepOldAnnotations,
});
await instance.logger.log(EventScope.uploadAnnotations);
await instance.annotations.clear(true);
await instance.actions.clear();
Expand Down
34 changes: 33 additions & 1 deletion cvat-ui/src/components/import-dataset/import-dataset-modal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const initialValues: FormValues = {
interface UploadParams {
resource: 'annotation' | 'dataset' | null;
convMaskToPoly: boolean;
keepOldAnnotations: boolean;
useDefaultSettings: boolean;
sourceStorage: Storage;
selectedFormat: string | null;
Expand Down Expand Up @@ -83,6 +84,7 @@ enum ReducerActionType {
SET_FILE_NAME = 'SET_FILE_NAME',
SET_SELECTED_FORMAT = 'SET_SELECTED_FORMAT',
SET_CONV_MASK_TO_POLY = 'SET_CONV_MASK_TO_POLY',
SET_KEEP_OLD_ANNOTATIONS = 'SET_KEEP_OLD_ANNOTATIONS',
SET_SOURCE_STORAGE = 'SET_SOURCE_STORAGE',
SET_RESOURCE = 'SET_RESOURCE',
}
Expand Down Expand Up @@ -121,6 +123,9 @@ export const reducerActions = {
setConvMaskToPoly: (convMaskToPoly: boolean) => (
createAction(ReducerActionType.SET_CONV_MASK_TO_POLY, { convMaskToPoly })
),
setKeepOldAnnotations: (keepOldAnnotations: boolean) => (
createAction(ReducerActionType.SET_KEEP_OLD_ANNOTATIONS, { keepOldAnnotations })
),
setSourceStorage: (sourceStorage: Storage) => (
createAction(ReducerActionType.SET_SOURCE_STORAGE, { sourceStorage })
),
Expand Down Expand Up @@ -246,6 +251,16 @@ const reducer = (state: State, action: ActionUnion<typeof reducerActions>): Stat
};
}

if (action.type === ReducerActionType.SET_KEEP_OLD_ANNOTATIONS) {
return {
...state,
uploadParams: {
...state.uploadParams,
keepOldAnnotations: action.payload.keepOldAnnotations,
},
};
}

if (action.type === ReducerActionType.SET_SOURCE_STORAGE) {
return {
...state,
Expand Down Expand Up @@ -292,6 +307,7 @@ function ImportDatasetModal(props: StateToProps): JSX.Element {
uploadParams: {
resource: null,
convMaskToPoly: true,
keepOldAnnotations: true,
useDefaultSettings: true,
sourceStorage: new Storage({
location: StorageLocation.LOCAL,
Expand Down Expand Up @@ -460,6 +476,7 @@ function ImportDatasetModal(props: StateToProps): JSX.Element {
uploadParams.sourceStorage,
uploadParams.file || uploadParams.fileName as string,
uploadParams.convMaskToPoly,
uploadParams.keepOldAnnotations,
));
const resToPrint = uploadParams.resource.charAt(0).toUpperCase() + uploadParams.resource.slice(1);
Notification.info({
Expand Down Expand Up @@ -488,7 +505,7 @@ function ImportDatasetModal(props: StateToProps): JSX.Element {

const handleImport = useCallback(
(): void => {
if (isAnnotation()) {
if (isAnnotation() && !uploadParams.keepOldAnnotations) {
confirmUpload();
} else {
onUpload();
Expand Down Expand Up @@ -538,6 +555,7 @@ function ImportDatasetModal(props: StateToProps): JSX.Element {
initialValues={{
...initialValues,
convMaskToPoly: uploadParams.convMaskToPoly,
keepOldAnnotations: uploadParams.keepOldAnnotations,
}}
onFinish={handleImport}
layout='vertical'
Expand Down Expand Up @@ -588,6 +606,20 @@ function ImportDatasetModal(props: StateToProps): JSX.Element {
)}
</Select>
</Form.Item>
<Space className='cvat-modal-import-switch-keep-old-annotations-container'>
<Form.Item
name='keepOldAnnotations'
valuePropName='checked'
className='cvat-modal-import-switch-keep-old-annotations'
>
<Switch
onChange={(value: boolean) => {
dispatch(reducerActions.setKeepOldAnnotations(value));
}}
/>
</Form.Item>
<Text strong>Keep Current Annotations</Text>
</Space>
<Space className='cvat-modal-import-switch-conv-mask-to-poly-container'>
<Form.Item
name='convMaskToPoly'
Expand Down
6 changes: 6 additions & 0 deletions cvat-ui/src/components/import-dataset/styles.scss
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@
.cvat-modal-import-switch-conv-mask-to-poly {
display: table-cell;
}
.cvat-modal-import-switch-keep-old-annotations {
display: table-cell;
}

.cvat-modal-import-switch-use-default-storage-container,
.cvat-modal-import-switch-conv-mask-to-poly-container {
width: 100%;
}
.cvat-modal-import-switch-keep-old-annotations-container {
width: 100%;
}
17 changes: 11 additions & 6 deletions cvat/apps/dataset_manager/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,10 @@ def import_annotations(self, src_file, importer, **options):
db_job=self.db_job,
create_callback=self.create,
)
self.delete()

keep_old_annotations = options.get('keep_old_annotations', True)
if not keep_old_annotations:
self.delete()

temp_dir_base = self.db_job.get_tmp_dirname()
os.makedirs(temp_dir_base, exist_ok=True)
Expand Down Expand Up @@ -796,7 +799,9 @@ def import_annotations(self, src_file, importer, **options):
db_task=self.db_task,
create_callback=self.create,
)
self.delete()
keep_old_annotations = options.get('keep_old_annotations', True)
if not keep_old_annotations:
self.delete()

temp_dir_base = self.db_task.get_tmp_dirname()
os.makedirs(temp_dir_base, exist_ok=True)
Expand Down Expand Up @@ -910,25 +915,25 @@ def export_task(task_id, dst_file, format_name, server_url=None, save_images=Fal
task.export(f, exporter, host=server_url, save_images=save_images)

@transaction.atomic
def import_task_annotations(src_file, task_id, format_name, conv_mask_to_poly):
def import_task_annotations(src_file, task_id, format_name, conv_mask_to_poly, keep_old_annotations):
task = TaskAnnotation(task_id)
task.init_from_db()

importer = make_importer(format_name)
with open(src_file, 'rb') as f:
try:
task.import_annotations(f, importer, conv_mask_to_poly=conv_mask_to_poly)
task.import_annotations(f, importer, conv_mask_to_poly=conv_mask_to_poly, keep_old_annotations=keep_old_annotations)
except (DatasetError, DatasetImportError, DatasetNotFoundError) as ex:
raise CvatImportError(str(ex))

@transaction.atomic
def import_job_annotations(src_file, job_id, format_name, conv_mask_to_poly):
def import_job_annotations(src_file, job_id, format_name, conv_mask_to_poly, keep_old_annotations):
job = JobAnnotation(job_id)
job.init_from_db()

importer = make_importer(format_name)
with open(src_file, 'rb') as f:
try:
job.import_annotations(f, importer, conv_mask_to_poly=conv_mask_to_poly)
job.import_annotations(f, importer, conv_mask_to_poly=conv_mask_to_poly, keep_old_annotations=keep_old_annotations)
except (DatasetError, DatasetImportError, DatasetNotFoundError) as ex:
raise CvatImportError(str(ex))
28 changes: 21 additions & 7 deletions cvat/apps/engine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def upload_finished(self, request):
format_name = request.query_params.get("format", "")
filename = request.query_params.get("filename", "")
conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True))
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
tmp_dir = self._object.get_tmp_dirname()
uploaded_file = None
if os.path.isfile(os.path.join(tmp_dir, filename)):
Expand All @@ -429,7 +430,8 @@ def upload_finished(self, request):
rq_func=dm.project.import_dataset_as_project,
db_obj=self._object,
format_name=format_name,
conv_mask_to_poly=conv_mask_to_poly
conv_mask_to_poly=conv_mask_to_poly,
keep_old_annotations=keep_old_annotations
)
elif self.action == 'import_backup':
filename = request.query_params.get("filename", "")
Expand Down Expand Up @@ -1003,6 +1005,7 @@ def _handle_upload_annotations(request):
format_name = request.query_params.get("format", "")
filename = request.query_params.get("filename", "")
conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True))
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
tmp_dir = self._object.get_tmp_dirname()
if os.path.isfile(os.path.join(tmp_dir, filename)):
annotation_file = os.path.join(tmp_dir, filename)
Expand All @@ -1014,6 +1017,7 @@ def _handle_upload_annotations(request):
db_obj=self._object,
format_name=format_name,
conv_mask_to_poly=conv_mask_to_poly,
keep_old_annotations=keep_old_annotations,
)
return Response(data='No such file were uploaded',
status=status.HTTP_400_BAD_REQUEST)
Expand Down Expand Up @@ -1347,19 +1351,22 @@ def annotations(self, request, pk):
elif request.method == 'POST' or request.method == 'OPTIONS':
# NOTE: initialization process of annotations import
format_name = request.query_params.get('format', '')
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
return self.import_annotations(
request=request,
db_obj=self._object,
import_func=_import_annotations,
rq_func=dm.task.import_task_annotations,
rq_id_template=self.IMPORT_RQ_ID_TEMPLATE
rq_id_template=self.IMPORT_RQ_ID_TEMPLATE,
keep_old_annotations=keep_old_annotations
)
elif request.method == 'PUT':
format_name = request.query_params.get('format', '')
if format_name:
# NOTE: continue process of import annotations
use_settings = to_bool(request.query_params.get('use_default_location', True))
conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True))
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
obj = self._object if use_settings else request.query_params
location_conf = get_location_configuration(
obj=obj, use_settings=use_settings, field_name=StorageType.SOURCE
Expand All @@ -1371,7 +1378,8 @@ def annotations(self, request, pk):
db_obj=self._object,
format_name=format_name,
location_conf=location_conf,
conv_mask_to_poly=conv_mask_to_poly
conv_mask_to_poly=conv_mask_to_poly,
keep_old_annotations=keep_old_annotations
)
else:
serializer = LabeledDataSerializer(data=request.data)
Expand Down Expand Up @@ -1647,6 +1655,7 @@ def upload_finished(self, request):
format_name = request.query_params.get("format", "")
filename = request.query_params.get("filename", "")
conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True))
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
tmp_dir = self.get_upload_dir()
if os.path.isfile(os.path.join(tmp_dir, filename)):
annotation_file = os.path.join(tmp_dir, filename)
Expand All @@ -1658,6 +1667,7 @@ def upload_finished(self, request):
db_obj=self._object,
format_name=format_name,
conv_mask_to_poly=conv_mask_to_poly,
keep_old_annotations=keep_old_annotations,
)
else:
return Response(data='No such file were uploaded',
Expand Down Expand Up @@ -1789,19 +1799,22 @@ def annotations(self, request, pk):

elif request.method == 'POST' or request.method == 'OPTIONS':
format_name = request.query_params.get('format', '')
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
return self.import_annotations(
request=request,
db_obj=self._object,
import_func=_import_annotations,
rq_func=dm.task.import_job_annotations,
rq_id_template=self.IMPORT_RQ_ID_TEMPLATE
rq_id_template=self.IMPORT_RQ_ID_TEMPLATE,
keep_old_annotations=keep_old_annotations
)

elif request.method == 'PUT':
format_name = request.query_params.get('format', '')
if format_name:
use_settings = to_bool(request.query_params.get('use_default_location', True))
conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True))
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
obj = self._object.segment.task if use_settings else request.query_params
location_conf = get_location_configuration(
obj=obj, use_settings=use_settings, field_name=StorageType.SOURCE
Expand All @@ -1813,7 +1826,8 @@ def annotations(self, request, pk):
db_obj=self._object,
format_name=format_name,
location_conf=location_conf,
conv_mask_to_poly=conv_mask_to_poly
conv_mask_to_poly=conv_mask_to_poly,
keep_old_annotations=keep_old_annotations
)
else:
serializer = LabeledDataSerializer(data=request.data)
Expand Down Expand Up @@ -2813,7 +2827,7 @@ def rq_exception_handler(rq_job, exc_type, exc_value, tb):
return True

def _import_annotations(request, rq_id_template, rq_func, db_obj, format_name,
filename=None, location_conf=None, conv_mask_to_poly=True):
filename=None, location_conf=None, conv_mask_to_poly=True, keep_old_annotations=True):

format_desc = {f.DISPLAY_NAME: f
for f in dm.views.get_import_formats()}.get(format_name)
Expand Down Expand Up @@ -2882,7 +2896,7 @@ def _import_annotations(request, rq_id_template, rq_func, db_obj, format_name,
filename = tf.name

func = import_resource_with_clean_up_after
func_args = (rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly)
func_args = (rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly, keep_old_annotations)

if location == Location.CLOUD_STORAGE:
func_args = (db_storage, key, func) + func_args
Expand Down