3939)
4040from array_api_extra ._lib ._backends import NUMPY_VERSION , Backend
4141from array_api_extra ._lib ._funcs import searchsorted as _funcs_searchsorted
42- from array_api_extra ._lib ._testing import xfail , xp_assert_close , xp_assert_equal
42+ from array_api_extra ._lib ._testing import xp_assert_close , xp_assert_equal
4343from array_api_extra ._lib ._utils ._compat import (
4444 array_namespace ,
45- is_jax_namespace ,
4645 is_torch_namespace ,
4746)
4847from array_api_extra ._lib ._utils ._compat import device as get_device
@@ -558,8 +557,6 @@ def test_complex(self, xp: ModuleType):
558557 expect = xp .asarray ([[1.0 , - 1.0j ], [1.0j , 1.0 ]], dtype = xp .complex128 )
559558 xp_assert_close (actual , expect )
560559
561- @pytest .mark .xfail_xp_backend (Backend .JAX_GPU , reason = "jax#32296" )
562- @pytest .mark .xfail_xp_backend (Backend .JAX , reason = "jax#32296" )
563560 def test_empty (self , xp : ModuleType ):
564561 with warnings .catch_warnings (record = True ):
565562 warnings .simplefilter ("always" , RuntimeWarning )
@@ -1399,7 +1396,6 @@ def test_assume_unique(self, xp: ModuleType):
13991396 @pytest .mark .parametrize ("shape2" , [(), (1 ,), (1 , 1 )])
14001397 def test_shapes (
14011398 self ,
1402- request : pytest .FixtureRequest ,
14031399 assume_unique : bool ,
14041400 shape1 : tuple [int , ...],
14051401 shape2 : tuple [int , ...],
@@ -1408,26 +1404,18 @@ def test_shapes(
14081404 x1 = xp .zeros (shape1 )
14091405 x2 = xp .zeros (shape2 )
14101406
1411- if is_jax_namespace (xp ) and assume_unique and shape1 != (1 ,):
1412- xfail (request = request , reason = "jax#32335 fixed with jax>=0.8.0" )
1413-
14141407 actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
14151408 xp_assert_equal (actual , xp .empty ((0 ,)))
14161409
14171410 @assume_unique
14181411 @pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "xp=xp" )
1419- def test_python_scalar (
1420- self , request : pytest .FixtureRequest , xp : ModuleType , assume_unique : bool
1421- ):
1412+ def test_python_scalar (self , xp : ModuleType , assume_unique : bool ):
14221413 # Test no dtype promotion to xp.asarray(x2); use x1.dtype
14231414 x1 = xp .asarray ([3 , 1 , 2 ], dtype = xp .int16 )
14241415 x2 = 3
14251416 actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
14261417 xp_assert_equal (actual , xp .asarray ([1 , 2 ], dtype = xp .int16 ))
14271418
1428- if is_jax_namespace (xp ) and assume_unique :
1429- xfail (request = request , reason = "jax#32335 fixed with jax>=0.8.0" )
1430-
14311419 actual = setdiff1d (x2 , x1 , assume_unique = assume_unique )
14321420 xp_assert_equal (actual , xp .asarray ([], dtype = xp .int16 ))
14331421
0 commit comments