Skip to content

Commit

Permalink
Merge pull request #1163 from emSko/1114_RLE_support_for_COCO
Browse files Browse the repository at this point in the history
1114 rle support for coco
  • Loading branch information
SkalskiP committed May 21, 2024
2 parents b916dc5 + 5840889 commit 1bbe03c
Show file tree
Hide file tree
Showing 12 changed files with 947 additions and 103 deletions.
1 change: 1 addition & 0 deletions docs/datasets.md → docs/datasets/core.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
---
comments: true
status: new
---

# Datasets
Expand Down
18 changes: 18 additions & 0 deletions docs/datasets/utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
---
comments: true
status: new
---

# Datasets Utils

<div class="md-typeset">
<h2><a href="#supervision.dataset.utils.rle_to_mask">rle_to_mask</a></h2>
</div>

:::supervision.dataset.utils.rle_to_mask

<div class="md-typeset">
<h2><a href="#supervision.dataset.utils.mask_to_rle">mask_to_rle</a></h2>
</div>

:::supervision.dataset.utils.mask_to_rle
12 changes: 12 additions & 0 deletions docs/detection/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,15 @@ status: new
</div>

:::supervision.detection.utils.pad_boxes

<div class="md-typeset">
<h2><a href="#supervision.detection.utils.contains_holes">contains_holes</a></h2>
</div>

:::supervision.detection.utils.contains_holes

<div class="md-typeset">
<h2><a href="#supervision.detection.utils.contains_multiple_segments">contains_multiple_segments</a></h2>
</div>

:::supervision.detection.utils.contains_multiple_segments
4 changes: 3 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ nav:
- Detection Smoother: detection/tools/smoother.md
- Save Detections: detection/tools/save_detections.md
- Trackers: trackers.md
- Datasets: datasets.md
- Datasets:
- Core: datasets/core.md
- Utils: datasets/utils.md
- Utils:
- Video: utils/video.md
- Image: utils/image.md
Expand Down
3 changes: 3 additions & 0 deletions supervision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ClassificationDataset,
DetectionDataset,
)
from supervision.dataset.utils import mask_to_rle, rle_to_mask
from supervision.detection.annotate import BoxAnnotator
from supervision.detection.core import Detections
from supervision.detection.line_zone import LineZone, LineZoneAnnotator
Expand All @@ -48,6 +49,8 @@
box_non_max_suppression,
calculate_masks_centroids,
clip_boxes,
contains_holes,
contains_multiple_segments,
filter_polygons_by_area,
mask_iou_batch,
mask_non_max_suppression,
Expand Down
34 changes: 23 additions & 11 deletions supervision/dataset/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,12 @@ def split(
Tuple[DetectionDataset, DetectionDataset]: A tuple containing
the training and testing datasets.
Example:
Examples:
```python
import supervision as sv
ds = sv.DetectionDataset(...)
train_ds, test_ds = ds.split(split_ratio=0.7,
random_state=42, shuffle=True)
train_ds, test_ds = ds.split(split_ratio=0.7, random_state=42, shuffle=True)
len(train_ds), len(test_ds)
# (700, 300)
```
Expand Down Expand Up @@ -229,7 +228,7 @@ def from_pascal_voc(
DetectionDataset: A DetectionDataset instance containing
the loaded images and annotations.
Example:
Examples:
```python
import roboflow
from roboflow import Roboflow
Expand Down Expand Up @@ -286,7 +285,7 @@ def from_yolo(
DetectionDataset: A DetectionDataset instance
containing the loaded images and annotations.
Example:
Examples:
```python
import roboflow
from roboflow import Roboflow
Expand Down Expand Up @@ -391,7 +390,7 @@ def from_coco(
DetectionDataset: A DetectionDataset instance containing
the loaded images and annotations.
Example:
Examples:
```python
import roboflow
from roboflow import Roboflow
Expand Down Expand Up @@ -431,6 +430,20 @@ def as_coco(
Exports the dataset to COCO format. This method saves the
images and their corresponding annotations in COCO format.
!!! tip
The format of the mask is determined automatically based on its structure:
- If a mask contains multiple disconnected components or holes, it will be
saved using the Run-Length Encoding (RLE) format for efficient storage and
processing.
- If a mask consists of a single, contiguous region without any holes, it
will be encoded as a polygon, preserving the outline of the object.
This automatic selection ensures that the masks are stored in the most
appropriate and space-efficient format, complying with COCO dataset
standards.
Args:
images_directory_path (Optional[str]): The path to the directory
where the images should be saved.
Expand Down Expand Up @@ -482,7 +495,7 @@ def merge(cls, dataset_list: List[DetectionDataset]) -> DetectionDataset:
(DetectionDataset): A single `DetectionDataset` object containing
the merged data from the input list.
Example:
Examples:
```python
import supervision as sv
Expand Down Expand Up @@ -567,13 +580,12 @@ def split(
Tuple[ClassificationDataset, ClassificationDataset]: A tuple containing
the training and testing datasets.
Example:
Examples:
```python
import supervision as sv
cd = sv.ClassificationDataset(...)
train_cd,test_cd = cd.split(split_ratio=0.7,
random_state=42,shuffle=True)
train_cd,test_cd = cd.split(split_ratio=0.7, random_state=42,shuffle=True)
len(train_cd), len(test_cd)
# (700, 300)
```
Expand Down Expand Up @@ -635,7 +647,7 @@ def from_folder_structure(cls, root_directory_path: str) -> ClassificationDatase
Returns:
ClassificationDataset: The dataset.
Example:
Examples:
```python
import roboflow
from roboflow import Roboflow
Expand Down
73 changes: 49 additions & 24 deletions supervision/dataset/formats/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@

import cv2
import numpy as np
import numpy.typing as npt

from supervision.dataset.utils import (
approximate_mask_with_polygons,
map_detections_class_id,
mask_to_rle,
rle_to_mask,
)
from supervision.detection.core import Detections
from supervision.detection.utils import polygon_to_mask
from supervision.detection.utils import (
contains_holes,
contains_multiple_segments,
polygon_to_mask,
)
from supervision.utils.file import read_json_file, save_json_file


Expand Down Expand Up @@ -57,13 +64,24 @@ def group_coco_annotations_by_image_id(
return annotations


def _polygons_to_masks(
polygons: List[np.ndarray], resolution_wh: Tuple[int, int]
) -> np.ndarray:
def coco_annotations_to_masks(
image_annotations: List[dict], resolution_wh: Tuple[int, int]
) -> npt.NDArray[np.bool_]:
return np.array(
[
polygon_to_mask(polygon=polygon, resolution_wh=resolution_wh)
for polygon in polygons
rle_to_mask(
rle=np.array(image_annotation["segmentation"]["counts"]),
resolution_wh=resolution_wh,
)
if image_annotation["iscrowd"]
else polygon_to_mask(
polygon=np.reshape(
np.asarray(image_annotation["segmentation"], dtype=np.int32),
(-1, 2),
),
resolution_wh=resolution_wh,
)
for image_annotation in image_annotations
],
dtype=bool,
)
Expand All @@ -83,13 +101,9 @@ def coco_annotations_to_detections(
xyxy[:, 2:4] += xyxy[:, 0:2]

if with_masks:
polygons = [
np.reshape(
np.asarray(image_annotation["segmentation"], dtype=np.int32), (-1, 2)
)
for image_annotation in image_annotations
]
mask = _polygons_to_masks(polygons=polygons, resolution_wh=resolution_wh)
mask = coco_annotations_to_masks(
image_annotations=image_annotations, resolution_wh=resolution_wh
)
return Detections(
class_id=np.asarray(class_ids, dtype=int), xyxy=xyxy, mask=mask
)
Expand All @@ -108,24 +122,35 @@ def detections_to_coco_annotations(
coco_annotations = []
for xyxy, mask, _, class_id, _, _ in detections:
box_width, box_height = xyxy[2] - xyxy[0], xyxy[3] - xyxy[1]
polygon = []
segmentation = []
iscrowd = 0
if mask is not None:
polygon = list(
approximate_mask_with_polygons(
mask=mask,
min_image_area_percentage=min_image_area_percentage,
max_image_area_percentage=max_image_area_percentage,
approximation_percentage=approximation_percentage,
)[0].flatten()
)
iscrowd = contains_holes(mask=mask) or contains_multiple_segments(mask=mask)

if iscrowd:
segmentation = {
"counts": mask_to_rle(mask=mask),
"size": list(mask.shape[:2]),
}
else:
segmentation = [
list(
approximate_mask_with_polygons(
mask=mask,
min_image_area_percentage=min_image_area_percentage,
max_image_area_percentage=max_image_area_percentage,
approximation_percentage=approximation_percentage,
)[0].flatten()
)
]
coco_annotation = {
"id": annotation_id,
"image_id": image_id,
"category_id": int(class_id),
"bbox": [xyxy[0], xyxy[1], box_width, box_height],
"area": box_width * box_height,
"segmentation": [polygon] if polygon else [],
"iscrowd": 0,
"segmentation": segmentation,
"iscrowd": iscrowd,
}
coco_annotations.append(coco_annotation)
annotation_id += 1
Expand Down

0 comments on commit 1bbe03c

Please sign in to comment.