@@ -598,6 +598,44 @@ def your_function(x, y):
598
598
# Alias of array_namespace — presumably kept so existing callers of
# get_namespace() keep working; confirm deprecation status before removing.
get_namespace = array_namespace
599
599
600
600
601
def _device_ctx(
    bare_xp: Namespace, device: Device, like: Array | None = None
) -> contextlib.AbstractContextManager:
    """Return a context manager which changes the current device in CuPy.

    Used internally by array creation functions in common._aliases.

    Parameters
    ----------
    bare_xp : Namespace
        The bare (unwrapped) array namespace: numpy, dask.array, or cupy.
    device : Device
        Target device, or None to keep the current device.
    like : Array, optional
        When ``device`` is None, infer the device from this array instead.

    Returns
    -------
    A context manager: a no-op ``nullcontext`` for device-less backends
    (NumPy, Dask), or the ``cupy.cuda.Device`` itself for CuPy, which is
    a context manager that makes the device current on ``__enter__``.

    Raises
    ------
    ValueError
        If ``device`` is not a valid device for NumPy or Dask.
    TypeError
        If ``bare_xp`` is CuPy and ``device`` is not a ``cupy.cuda.Device``.
    """
    # NOTE: the original annotation was ``Generator[None]``, but this
    # function never yields — it returns ready-made context managers.
    if device is None:
        if like is None:
            return contextlib.nullcontext()
        device = _device(like)

    # sys.modules.get() avoids importing backends that are not loaded.
    if bare_xp is sys.modules.get('numpy'):
        if device != "cpu":
            raise ValueError(f"Unsupported device for NumPy: {device!r}")
        return contextlib.nullcontext()

    if bare_xp is sys.modules.get('dask.array'):
        # _DASK_DEVICE is the module-level placeholder for "non-CPU dask device".
        if device not in ("cpu", _DASK_DEVICE):
            raise ValueError(f"Unsupported device for Dask: {device!r}")
        return contextlib.nullcontext()

    if bare_xp is sys.modules.get('cupy'):
        if not isinstance(device, bare_xp.cuda.Device):
            raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
        # cupy.cuda.Device is itself a context manager that switches the
        # current CUDA device for the duration of the ``with`` block.
        return device

    # PyTorch doesn't have a "current device" context manager and you
    # can't use array creation functions from common._aliases.
    raise AssertionError("unreachable")  # pragma: nocover
631
+
632
+
633
def _check_device(bare_xp: Namespace, device: Device) -> None:
    """Validate dummy device on device-less array backends.

    Delegates all validation to ``_device_ctx``; entering and immediately
    exiting the returned context performs the checks and raises on an
    invalid device.
    """
    ctx = _device_ctx(bare_xp, device)
    with ctx:
        pass
601
639
# Placeholder object to represent the dask device
# when the array backend is not the CPU.
# (since it is not easy to tell which device a dask array is on)
# Singleton sentinel: compared by identity in device checks elsewhere.
_DASK_DEVICE = _dask_device()
609
647
610
-
611
648
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
612
649
# or cupy.ndarray. They are not included in array objects of this library
613
650
# because this library just reuses the respective ndarray classes without
@@ -799,43 +836,6 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
799
836
return x .to_device (device , stream = stream )
800
837
801
838
802
def _device_ctx(
    bare_xp: Namespace, device: Device, like: Array | None = None
) -> Generator[None]:
    """Context manager which changes the current device in CuPy.

    Used internally by array creation functions in common._aliases.
    """
    if device is None and like is None:
        return contextlib.nullcontext()
    if device is None:
        device = _device(like)

    # Resolve the backend modules once; sys.modules.get() never imports,
    # so unloaded backends simply resolve to None.
    numpy_mod = sys.modules.get('numpy')
    dask_mod = sys.modules.get('dask.array')
    cupy_mod = sys.modules.get('cupy')

    if bare_xp is numpy_mod:
        if device != "cpu":
            raise ValueError(f"Unsupported device for NumPy: {device!r}")
        return contextlib.nullcontext()

    if bare_xp is dask_mod:
        if device not in ("cpu", _DASK_DEVICE):
            raise ValueError(f"Unsupported device for Dask: {device!r}")
        return contextlib.nullcontext()

    if bare_xp is cupy_mod:
        if not isinstance(device, cupy_mod.cuda.Device):
            raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
        # cupy.cuda.Device doubles as a context manager.
        return device

    # PyTorch doesn't have a "current device" context manager and you
    # can't use array creation functions from common._aliases.
    raise AssertionError("unreachable")  # pragma: nocover
834
def _check_device(bare_xp: Namespace, device: Device) -> None:
    """Raise if ``device`` is not a valid dummy device for ``bare_xp``."""
    # All the validation lives in _device_ctx; we only need its side effects.
    with _device_ctx(bare_xp, device):
        pass
839
839
def size (x : Array ) -> int | None :
840
840
"""
841
841
Return the total number of elements of x.
0 commit comments