@@ -782,32 +782,46 @@ def inject_fake_data(self, tmpdir, config):
782782
783783 annotation_folder = tmpdir / self ._ANNOTATIONS_FOLDER
784784 os .makedirs (annotation_folder )
785+
786+ segmentation_kind = config .pop ("segmentation_kind" , "list" )
785787 info = self ._create_annotation_file (
786- annotation_folder , self ._ANNOTATIONS_FILE , file_names , num_annotations_per_image
788+ annotation_folder ,
789+ self ._ANNOTATIONS_FILE ,
790+ file_names ,
791+ num_annotations_per_image ,
792+ segmentation_kind = segmentation_kind ,
787793 )
788794
789795 info ["num_examples" ] = num_images
790796 return info
791797
792- def _create_annotation_file (self , root , name , file_names , num_annotations_per_image ):
798+ def _create_annotation_file (self , root , name , file_names , num_annotations_per_image , segmentation_kind = "list" ):
793799 image_ids = [int (file_name .stem ) for file_name in file_names ]
794800 images = [dict (file_name = str (file_name ), id = id ) for file_name , id in zip (file_names , image_ids )]
795801
796- annotations , info = self ._create_annotations (image_ids , num_annotations_per_image )
802+ annotations , info = self ._create_annotations (image_ids , num_annotations_per_image , segmentation_kind )
797803 self ._create_json (root , name , dict (images = images , annotations = annotations ))
798804
799805 return info
800806
801- def _create_annotations (self , image_ids , num_annotations_per_image ):
807+ def _create_annotations (self , image_ids , num_annotations_per_image , segmentation_kind = "list" ):
802808 annotations = []
803809 annotion_id = 0
810+
804811 for image_id in itertools .islice (itertools .cycle (image_ids ), len (image_ids ) * num_annotations_per_image ):
812+ segmentation = {
813+ "list" : [torch .rand (8 ).tolist ()],
814+ "rle" : {"size" : [10 , 10 ], "counts" : [1 ]},
815+ "rle_encoded" : {"size" : [2400 , 2400 ], "counts" : "PQRQ2[1\\ Y2f0gNVNRhMg2" },
816+ "bad" : 123 ,
817+ }[segmentation_kind ]
818+
805819 annotations .append (
806820 dict (
807821 image_id = image_id ,
808822 id = annotion_id ,
809823 bbox = torch .rand (4 ).tolist (),
810- segmentation = [ torch . rand ( 8 ). tolist ()] ,
824+ segmentation = segmentation ,
811825 category_id = int (torch .randint (91 , ())),
812826 area = float (torch .rand (1 )),
813827 iscrowd = int (torch .randint (2 , size = (1 ,))),
@@ -832,11 +846,27 @@ def test_slice_error(self):
832846 with pytest .raises (ValueError , match = "Index must be of type integer" ):
833847 dataset [:2 ]
834848
849+ def test_segmentation_kind (self ):
850+ if isinstance (self , CocoCaptionsTestCase ):
851+ return
852+
853+ for segmentation_kind in ("list" , "rle" , "rle_encoded" ):
854+ config = {"segmentation_kind" : segmentation_kind }
855+ with self .create_dataset (config ) as (dataset , _ ):
856+ dataset = datasets .wrap_dataset_for_transforms_v2 (dataset , target_keys = "all" )
857+ list (dataset )
858+
859+ config = {"segmentation_kind" : "bad" }
860+ with self .create_dataset (config ) as (dataset , _ ):
861+ dataset = datasets .wrap_dataset_for_transforms_v2 (dataset , target_keys = "all" )
862+ with pytest .raises (ValueError , match = "COCO segmentation expected to be a dict or a list" ):
863+ list (dataset )
864+
835865
836866class CocoCaptionsTestCase (CocoDetectionTestCase ):
837867 DATASET_CLASS = datasets .CocoCaptions
838868
839- def _create_annotations (self , image_ids , num_annotations_per_image ):
869+ def _create_annotations (self , image_ids , num_annotations_per_image , segmentation_kind = "list" ):
840870 captions = [str (idx ) for idx in range (num_annotations_per_image )]
841871 annotations = combinations_grid (image_id = image_ids , caption = captions )
842872 for id , annotation in enumerate (annotations ):
0 commit comments