@@ -879,6 +879,60 @@ def test_ordered_api(self):
879879 tm .assert_index_equal (cat4 .categories , Index (['b' , 'c' , 'a' ]))
880880 assert cat4 .ordered
881881
882+ def test_set_dtype_same (self ):
883+ c = Categorical (['a' , 'b' , 'c' ])
884+ result = c ._set_dtype (CategoricalDtype (['a' , 'b' , 'c' ]))
885+ tm .assert_categorical_equal (result , c )
886+
887+ def test_set_dtype_new_categories (self ):
888+ c = Categorical (['a' , 'b' , 'c' ])
889+ result = c ._set_dtype (CategoricalDtype (['a' , 'b' , 'c' , 'd' ]))
890+ tm .assert_numpy_array_equal (result .codes , c .codes )
891+ tm .assert_index_equal (result .dtype .categories ,
892+ pd .Index (['a' , 'b' , 'c' , 'd' ]))
893+
894+ def test_set_dtype_nans (self ):
895+ c = Categorical (['a' , 'b' , np .nan ])
896+ result = c ._set_dtype (CategoricalDtype (['a' , 'c' ]))
897+ tm .assert_numpy_array_equal (result .codes , np .array ([0 , - 1 , - 1 ],
898+ dtype = 'int8' ))
899+
900+ @pytest .mark .parametrize ('values, categories, new_categories' , [
901+ # No NaNs, same cats, same order
902+ (['a' , 'b' , 'a' ], ['a' , 'b' ], ['a' , 'b' ],),
903+ # No NaNs, same cats, different order
904+ (['a' , 'b' , 'a' ], ['a' , 'b' ], ['b' , 'a' ],),
905+ # Same, unsorted
906+ (['b' , 'a' , 'a' ], ['a' , 'b' ], ['a' , 'b' ],),
907+ # No NaNs, same cats, different order
908+ (['b' , 'a' , 'a' ], ['a' , 'b' ], ['b' , 'a' ],),
909+ # NaNs
910+ (['a' , 'b' , 'c' ], ['a' , 'b' ], ['a' , 'b' ]),
911+ (['a' , 'b' , 'c' ], ['a' , 'b' ], ['b' , 'a' ]),
912+ (['b' , 'a' , 'c' ], ['a' , 'b' ], ['a' , 'b' ]),
913+ (['b' , 'a' , 'c' ], ['a' , 'b' ], ['a' , 'b' ]),
914+ # Introduce NaNs
915+ (['a' , 'b' , 'c' ], ['a' , 'b' ], ['a' ]),
916+ (['a' , 'b' , 'c' ], ['a' , 'b' ], ['b' ]),
917+ (['b' , 'a' , 'c' ], ['a' , 'b' ], ['a' ]),
918+ (['b' , 'a' , 'c' ], ['a' , 'b' ], ['a' ]),
919+ # No overlap
920+ (['a' , 'b' , 'c' ], ['a' , 'b' ], ['d' , 'e' ]),
921+ ])
922+ @pytest .mark .parametrize ('ordered' , [True , False ])
923+ def test_set_dtype_many (self , values , categories , new_categories ,
924+ ordered ):
925+ c = Categorical (values , categories )
926+ expected = Categorical (values , new_categories , ordered )
927+ result = c ._set_dtype (expected .dtype )
928+ tm .assert_categorical_equal (result , expected )
929+
930+ def test_set_dtype_no_overlap (self ):
931+ c = Categorical (['a' , 'b' , 'c' ], ['d' , 'e' ])
932+ result = c ._set_dtype (CategoricalDtype (['a' , 'b' ]))
933+ expected = Categorical ([None , None , None ], categories = ['a' , 'b' ])
934+ tm .assert_categorical_equal (result , expected )
935+
882936 def test_set_ordered (self ):
883937
884938 cat = Categorical (["a" , "b" , "c" , "a" ], ordered = True )
0 commit comments