diff --git a/onedal/__init__.py b/onedal/__init__.py index 8f7ade667c..083fb5a676 100644 --- a/onedal/__init__.py +++ b/onedal/__init__.py @@ -19,6 +19,23 @@ from daal4py.sklearn._utils import daal_check_version + +class Backend: + """Encapsulates the backend module and provides a unified interface to it together with additional properties about dpc/spmd policies""" + + def __init__(self, backend_module, is_dpc, is_spmd): + self.backend = backend_module + self.is_dpc = is_dpc + self.is_spmd = is_spmd + + # accessing the instance will return the backend_module + def __getattr__(self, name): + return getattr(self.backend, name) + + def __repr__(self) -> str: + return f"Backend({self.backend}, is_dpc={self.is_dpc}, is_spmd={self.is_spmd})" + + if "Windows" in platform.system(): import os import site @@ -40,44 +57,67 @@ pass os.environ["PATH"] = path_to_libs + os.pathsep + os.environ["PATH"] -try: - import onedal._onedal_py_dpc as _backend - - _is_dpc_backend = True -except ImportError: - import onedal._onedal_py_host as _backend - - _is_dpc_backend = False -_is_spmd_backend = False +try: + # use dpc backend if available + import onedal._onedal_py_dpc -if _is_dpc_backend: - try: - import onedal._onedal_py_spmd_dpc as _spmd_backend + _dpc_backend = Backend(onedal._onedal_py_dpc, is_dpc=True, is_spmd=False) - _is_spmd_backend = True - except ImportError: - _is_spmd_backend = False + _host_backend = None +except ImportError: + # fall back to host backend + _dpc_backend = None + import onedal._onedal_py_host -__all__ = ["covariance", "decomposition", "ensemble", "neighbors", "primitives", "svm"] + _host_backend = Backend(onedal._onedal_py_host, is_dpc=False, is_spmd=False) -if _is_spmd_backend: - __all__.append("spmd") +try: + # also load spmd backend if available + import onedal._onedal_py_spmd_dpc + _spmd_backend = Backend(onedal._onedal_py_spmd_dpc, is_dpc=True, is_spmd=True) +except ImportError: + _spmd_backend = None + +# if/elif/else layout required for pylint to realize _default_backend cannot be None +if _dpc_backend is not None: + _default_backend = _dpc_backend +elif _host_backend is not None: + _default_backend = _host_backend +else: + raise ImportError("No oneDAL backend available") + +# Core modules to export +__all__ = [ + "_host_backend", + "_default_backend", + "_dpc_backend", + "_spmd_backend", + "covariance", + "decomposition", + "ensemble", + "neighbors", + "primitives", + "svm", +] + +# Additional features based on version checks if daal_check_version((2023, "P", 100)): __all__ += ["basic_statistics", "linear_model"] +if daal_check_version((2023, "P", 200)): + __all__ += ["cluster"] - if _is_spmd_backend: +# Exports if SPMD backend is available +if _spmd_backend is not None: + __all__ += ["spmd"] + if daal_check_version((2023, "P", 100)): __all__ += [ "spmd.basic_statistics", "spmd.decomposition", "spmd.linear_model", "spmd.neighbors", ] - -if daal_check_version((2023, "P", 200)): - __all__ += ["cluster"] - - if _is_spmd_backend: + if daal_check_version((2023, "P", 200)): __all__ += ["spmd.cluster"] diff --git a/onedal/_device_offload.py b/onedal/_device_offload.py index 4e46592bb2..7b0a327fd2 100644 --- a/onedal/_device_offload.py +++ b/onedal/_device_offload.py @@ -14,11 +14,13 @@ # limitations under the License. 
 # ==============================================================================

-import logging
+import inspect
 from collections.abc import Iterable
+from contextlib import contextmanager
 from functools import wraps

 import numpy as np
+from scipy import sparse as sp
 from sklearn import get_config

 from ._config import _get_config
@@ -30,13 +32,153 @@
     from dpctl.memory import MemoryUSMDevice, as_usm_memory
     from dpctl.tensor import usm_ndarray
 else:
-    import onedal
+    from onedal import _dpc_backend
+
+    SyclQueue = getattr(_dpc_backend, "SyclQueue", None)
+
+
+class SyclQueueManager:
+    """Manage global and data SyclQueues"""
+
+    # single instance of global queue
+    __global_queue = None
+
+    @staticmethod
+    def __create_sycl_queue(target):
+        if SyclQueue is None:
+            # we don't have SyclQueue support
+            return None
+        if target is None:
+            return None
+        if isinstance(target, SyclQueue):
+            return target
+        if isinstance(target, (str, int)):
+            return SyclQueue(target)
+        raise ValueError(f"Invalid queue or device selector {target=}.")
+
+    @staticmethod
+    def get_global_queue():
+        """Get the global queue. Retrieve it from the config if not set."""
+        if (queue := SyclQueueManager.__global_queue) is not None:
+            if not isinstance(queue, SyclQueue):
+                raise ValueError("Global queue is not a SyclQueue object.")
+            return queue
+
+        target = _get_config()["target_offload"]
+        if target == "auto":
+            # queue will be created from the provided data to each function call
+            return None
+
+        q = SyclQueueManager.__create_sycl_queue(target)
+        SyclQueueManager.update_global_queue(q)
+        return q
+
+    @staticmethod
+    def remove_global_queue():
+        """Remove the global queue."""
+        SyclQueueManager.__global_queue = None
+
+    @staticmethod
+    def update_global_queue(queue):
+        """Update the global queue."""
+        queue = SyclQueueManager.__create_sycl_queue(queue)
+        SyclQueueManager.__global_queue = queue
+
+    @staticmethod
+    def from_data(*data):
+        """Extract the queue from provided data. This updates the global queue as well."""
+        for item in data:
+            # iterate through all data objects, extract the queue, and verify that all data objects are on the same device
+
+            # get the `usm_interface` - the C++ implementation might throw an exception if the data type is not supported
+            try:
+                usm_iface = getattr(item, "__sycl_usm_array_interface__", None)
+            except RuntimeError as e:
+                if "SUA interface" in str(e):
+                    # ignore SUA interface errors and move on
+                    continue
+                else:
+                    # unexpected, re-raise
+                    raise e

+            if usm_iface is None:
+                # no interface found - try next data object
+                continue
+
+            # extract the queue
+            global_queue = SyclQueueManager.get_global_queue()
+            data_queue = usm_iface["syclobj"]
+            if not data_queue:
+                # no queue, i.e. host data, no more work to do
+                continue
+
+            # update the global queue if not set
+            if global_queue is None:
+                SyclQueueManager.update_global_queue(data_queue)
+                global_queue = data_queue
+
+            # if both queues point to a device, they must point to the same device
+            data_dev = data_queue.sycl_device
+            global_dev = global_queue.sycl_device
+            if data_dev is not None and global_dev is not None and data_dev != global_dev:
+                raise ValueError(
+                    "Data objects are located on different target devices or not on selected device."
+                )
+
+        # after we went through the data, global queue is updated and verified (if any queue found)
+        return SyclQueueManager.get_global_queue()
+
+    @staticmethod
+    @contextmanager
+    def manage_global_queue(queue, *args):
+        """
+        Context manager to manage the global SyclQueue.
+ + This context manager updates the global queue with the provided queue, + verifies that all data objects are on the same device, and restores the + original queue after work is done. + Note: For most applications, the original queue should be `None`, but + if there are nested calls to `manage_global_queue()`, it is + important to restore the outer queue, rather than setting it to + `None`. + + Parameters: + queue (SyclQueue or None): The queue to set as the global queue. If None, + the global queue will be determined from the provided data. + *args: Additional data objects to verify their device placement. + + Yields: + SyclQueue: The global queue after verification. + """ + original_queue = SyclQueueManager.get_global_queue() + try: + # update the global queue with what is provided, it can be None, then we will get it from provided data + SyclQueueManager.update_global_queue(queue) + # find the queues in data using SyclQueueManager to verify that all data objects are on the same device + yield SyclQueueManager.from_data(*args) + finally: + # restore the original queue + SyclQueueManager.update_global_queue(original_queue) + + +def supports_queue(func): + """ + Decorator that updates the global queue based on provided queue and global configuration. + If a `queue` keyword argument is provided in the decorated function, its value will be used globally. + If no queue is provided, the global queue will be updated from the provided data. + In either case, all data objects are verified to be on the same device (or on host). + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + queue = kwargs.get("queue", None) + with SyclQueueManager.manage_global_queue(queue, *args) as queue: + kwargs["queue"] = queue + result = func(self, *args, **kwargs) + return result + + return wrapper - # setting fallback to `object` will make if isinstance call - # in _get_global_queue always true for situations without the - # dpc backend when `device_offload` is used. 
Instead, it will - # fail at the policy check phase yielding a RuntimeError - SyclQueue = getattr(onedal._backend, "SyclQueue", object) if dpnp_available: import dpnp @@ -69,7 +211,7 @@ def _copy_to_usm(queue, array): return array -def _transfer_to_host(queue, *data): +def _transfer_to_host(*data): has_usm_data, has_host_data = False, False host_data = [] @@ -82,13 +224,6 @@ def _transfer_to_host(queue, *data): "dpctl need to be installed to work " "with __sycl_usm_array_interface__" ) - if queue is not None: - if queue.sycl_device != usm_iface["syclobj"].sycl_device: - raise RuntimeError( - "Input data shall be located " "on single target device" - ) - else: - queue = usm_iface["syclobj"] buffer = as_usm_memory(item).copy_to_host() order = "C" @@ -117,88 +252,62 @@ def _transfer_to_host(queue, *data): raise RuntimeError("Input data shall be located on single target device") host_data.append(item) - return has_usm_data, queue, host_data - - -def _get_global_queue(): - target = _get_config()["target_offload"] - - if target != "auto": - if isinstance(target, SyclQueue): - return target - return SyclQueue(target) - return None + return has_usm_data, host_data def _get_host_inputs(*args, **kwargs): - q = _get_global_queue() - _, q, hostargs = _transfer_to_host(q, *args) - _, q, hostvalues = _transfer_to_host(q, *kwargs.values()) + _, hostargs = _transfer_to_host(*args) + _, hostvalues = _transfer_to_host(*kwargs.values()) hostkwargs = dict(zip(kwargs.keys(), hostvalues)) - return q, hostargs, hostkwargs - + return hostargs, hostkwargs -def _run_on_device(func, obj=None, *args, **kwargs): - if obj is not None: - return func(obj, *args, **kwargs) - return func(*args, **kwargs) - -def support_input_format(freefunc=False, queue_param=True): +def support_input_format(func): """ Converts and moves the output arrays of the decorated function to match the input array type and device. Puts SYCLQueue from data to decorated function arguments. + """ - Parameters - ---------- - freefunc (bool) : Set to True if decorates free function. - queue_param (bool) : Set to False if the decorated function has no `queue` parameter + def invoke_func(self_or_None, *args, **kwargs): + if self_or_None is None: + return func(*args, **kwargs) + else: + return func(self_or_None, *args, **kwargs) + + @wraps(func) + def wrapper_impl(*args, **kwargs): + # remove self from args if it is a class method + if inspect.isfunction(func) and "." in func.__qualname__: + self = args[0] + args = args[1:] + else: + self = None - Notes - ----- - Queue will not be changed if provided explicitly. 
- """ + if len(args) == 0 and len(kwargs) == 0: + return invoke_func(self, *args, **kwargs) + + data = (*args, *kwargs.values()) + # get and set the global queue from the kwarg or data + with SyclQueueManager.manage_global_queue(kwargs.get("queue"), *args) as queue: + hostargs, hostkwargs = _get_host_inputs(*args, **kwargs) + if "queue" in inspect.signature(func).parameters: + # set the queue if it's expected by func + hostkwargs["queue"] = queue + result = invoke_func(self, *hostargs, **hostkwargs) - def decorator(func): - def wrapper_impl(obj, *args, **kwargs): - if len(args) == 0 and len(kwargs) == 0: - return _run_on_device(func, obj, *args, **kwargs) - data = (*args, *kwargs.values()) - data_queue, hostargs, hostkwargs = _get_host_inputs(*args, **kwargs) - if queue_param and not ( - "queue" in hostkwargs and hostkwargs["queue"] is not None - ): - hostkwargs["queue"] = data_queue - result = _run_on_device(func, obj, *hostargs, **hostkwargs) usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None) - if usm_iface is not None: - result = _copy_to_usm(data_queue, result) + if queue is not None and usm_iface is not None: + result = _copy_to_usm(queue, result) if dpnp_available and isinstance(data[0], dpnp.ndarray): result = _convert_to_dpnp(result) return result - config = get_config() - if not ("transform_output" in config and config["transform_output"]): - input_array_api = getattr(data[0], "__array_namespace__", lambda: None)() - if input_array_api: - input_array_api_device = data[0].device - result = _asarray( - result, input_array_api, device=input_array_api_device - ) - return result - - if freefunc: - - @wraps(func) - def wrapper_free(*args, **kwargs): - return wrapper_impl(None, *args, **kwargs) - - return wrapper_free - - @wraps(func) - def wrapper_with_self(self, *args, **kwargs): - return wrapper_impl(self, *args, **kwargs) - return wrapper_with_self + if not get_config().get("transform_output"): + input_array_api = getattr(data[0], "__array_namespace__", lambda: None)() + if input_array_api: + input_array_api_device = data[0].device + result = _asarray(result, input_array_api, device=input_array_api_device) + return result - return decorator + return wrapper_impl diff --git a/onedal/basic_statistics/basic_statistics.py b/onedal/basic_statistics/basic_statistics.py index 56904adce2..524091d078 100644 --- a/onedal/basic_statistics/basic_statistics.py +++ b/onedal/basic_statistics/basic_statistics.py @@ -14,23 +14,26 @@ # limitations under the License. # ============================================================================== -import warnings from abc import ABCMeta, abstractmethod import numpy as np -from ..common._base import BaseEstimator +from onedal._device_offload import supports_queue + +from ..common._backend import bind_default_backend from ..datatypes import from_table, to_table -from ..utils import _is_csr -from ..utils.validation import _check_array +from ..utils.validation import _check_array, _is_csr -class BaseBasicStatistics(BaseEstimator, metaclass=ABCMeta): +class BaseBasicStatistics(metaclass=ABCMeta): @abstractmethod def __init__(self, result_options, algorithm): self.options = result_options self.algorithm = algorithm + @bind_default_backend("basic_statistics") + def compute(self, params, data_table, weights_table): ... 
+
     @staticmethod
     def get_all_result_options():
         return [
@@ -71,9 +74,8 @@ class BasicStatistics(BaseBasicStatistics):
     def __init__(self, result_options="all", algorithm="by_default"):
         super().__init__(result_options, algorithm)

+    @supports_queue
     def fit(self, data, sample_weight=None, queue=None):
-        policy = self._get_policy(queue, data, sample_weight)
-
         is_csr = _is_csr(data)

         if data is not None and not is_csr:
@@ -85,7 +87,9 @@ def fit(self, data, sample_weight=None, queue=None):
         data_table, weights_table = to_table(data, sample_weight, queue=queue)

         dtype = data_table.dtype
-        raw_result = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)
+        raw_result = self._compute_raw(
+            data_table, weights_table, dtype, is_csr
+        )
         for opt, raw_value in raw_result.items():
             value = from_table(raw_value).ravel()
             if is_single_dim:
@@ -95,12 +99,9 @@ def fit(self, data, sample_weight=None, queue=None):

         return self

-    def _compute_raw(
-        self, data_table, weights_table, policy, dtype=np.float32, is_csr=False
-    ):
-        module = self._get_backend("basic_statistics")
+    def _compute_raw(self, data_table, weights_table, dtype=np.float32, is_csr=False):
         params = self._get_onedal_params(is_csr, dtype)
-        result = module.compute(policy, params, data_table, weights_table)
+        result = self.compute(params, data_table, weights_table)
         options = self._get_result_options(self.options).split("|")

         return {opt: getattr(result, opt) for opt in options}
diff --git a/onedal/basic_statistics/incremental_basic_statistics.py b/onedal/basic_statistics/incremental_basic_statistics.py
index b98161ce59..a5da78e5f1 100644
--- a/onedal/basic_statistics/incremental_basic_statistics.py
+++ b/onedal/basic_statistics/incremental_basic_statistics.py
@@ -17,9 +17,11 @@
 import numpy as np

 from daal4py.sklearn._utils import get_dtype
+from onedal._device_offload import SyclQueueManager, supports_queue
+from onedal.common._backend import bind_default_backend

 from ..datatypes import from_table, to_table
-from ..utils import _check_array
+from ..utils.validation import _check_array
 from .basic_statistics import BaseBasicStatistics

@@ -68,12 +70,22 @@ class IncrementalBasicStatistics(BaseBasicStatistics):
     def __init__(self, result_options="all"):
         super().__init__(result_options, algorithm="by_default")
         self._reset()
+        self._queue = None
+
+    @bind_default_backend("basic_statistics")
+    def partial_compute_result(self): ...
+
+    @bind_default_backend("basic_statistics")
+    def partial_compute(self, *args, **kwargs): ...
+
+    @bind_default_backend("basic_statistics")
+    def finalize_compute(self, *args, **kwargs): ...

     def _reset(self):
         self._need_to_finalize = False
-        self._partial_result = self._get_backend(
-            "basic_statistics", None, "partial_compute_result"
-        )
+        self._queue = None
+        # get the _partial_result pointer from backend
+        self._partial_result = self.partial_compute_result()

     def __getstate__(self):
         # Since finalize_fit can't be dispatched without directly provided queue
@@ -85,6 +97,7 @@ def __getstate__(self):

         return data

+    @supports_queue
     def partial_fit(self, X, weights=None, queue=None):
         """
         Computes partial data for basic statistics
@@ -105,7 +118,6 @@
             Returns the instance itself.
""" self._queue = queue - policy = self._get_policy(queue, X) X = _check_array( X, dtype=[np.float64, np.float32], ensure_2d=False, force_all_finite=False @@ -123,21 +135,14 @@ def partial_fit(self, X, weights=None, queue=None): self._onedal_params = self._get_onedal_params(False, dtype=dtype) X_table, weights_table = to_table(X, weights, queue=queue) - self._partial_result = self._get_backend( - "basic_statistics", - None, - "partial_compute", - policy, - self._onedal_params, - self._partial_result, - X_table, - weights_table, + self._partial_result = self.partial_compute( + self._onedal_params, self._partial_result, X_table, weights_table ) self._need_to_finalize = True - return self + self._queue = queue - def finalize_fit(self, queue=None): + def finalize_fit(self): """ Finalizes basic statistics computation and obtains result attributes from the current `_partial_result`. @@ -153,19 +158,9 @@ def finalize_fit(self, queue=None): Returns the instance itself. """ if self._need_to_finalize: - if queue is not None: - policy = self._get_policy(queue) - else: - policy = self._get_policy(self._queue) - - result = self._get_backend( - "basic_statistics", - None, - "finalize_compute", - policy, - self._onedal_params, - self._partial_result, - ) + with SyclQueueManager.manage_global_queue(self._queue): + result = self.finalize_compute(self._onedal_params, self._partial_result) + options = self._get_result_options(self.options).split("|") for opt in options: setattr(self, opt, from_table(getattr(result, opt)).ravel()) diff --git a/onedal/cluster/dbscan.py b/onedal/cluster/dbscan.py index 02dcfb6a58..6c009f4669 100644 --- a/onedal/cluster/dbscan.py +++ b/onedal/cluster/dbscan.py @@ -17,14 +17,15 @@ import numpy as np from daal4py.sklearn._utils import get_dtype, make2d +from onedal._device_offload import supports_queue +from onedal.common._backend import bind_default_backend -from ..common._base import BaseEstimator from ..common._mixin import ClusterMixin from ..datatypes import from_table, to_table -from ..utils import _check_array +from ..utils.validation import _check_array -class BaseDBSCAN(BaseEstimator, ClusterMixin): +class DBSCAN(ClusterMixin): def __init__( self, eps=0.5, @@ -46,6 +47,9 @@ def __init__( self.p = p self.n_jobs = n_jobs + @bind_default_backend("dbscan.clustering") + def compute(self, params, data_table, weights_table): ... 
+ def _get_onedal_params(self, dtype=np.float32): return { "fptype": dtype, @@ -56,14 +60,14 @@ def _get_onedal_params(self, dtype=np.float32): "result_options": "core_observation_indices|responses", } - def _fit(self, X, y, sample_weight, module, queue): - policy = self._get_policy(queue, X) + @supports_queue + def fit(self, X, y=None, sample_weight=None, queue=None): X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32]) sample_weight = make2d(sample_weight) if sample_weight is not None else None X_table, sample_weight_table = to_table(X, sample_weight, queue=queue) params = self._get_onedal_params(X_table.dtype) - result = module.compute(policy, params, X_table, sample_weight_table) + result = self.compute(params, X_table, sample_weight_table) self.labels_ = from_table(result.responses).ravel() if result.core_observation_indices is not None: @@ -75,31 +79,3 @@ def _fit(self, X, y, sample_weight, module, queue): self.components_ = np.take(X, self.core_sample_indices_, axis=0) self.n_features_in_ = X.shape[1] return self - - -class DBSCAN(BaseDBSCAN): - def __init__( - self, - eps=0.5, - *, - min_samples=5, - metric="euclidean", - metric_params=None, - algorithm="auto", - leaf_size=30, - p=None, - n_jobs=None, - ): - self.eps = eps - self.min_samples = min_samples - self.metric = metric - self.metric_params = metric_params - self.algorithm = algorithm - self.leaf_size = leaf_size - self.p = p - self.n_jobs = n_jobs - - def fit(self, X, y=None, sample_weight=None, queue=None): - return super()._fit( - X, y, sample_weight, self._get_backend("dbscan", "clustering", None), queue - ) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index a40729841d..de2c2012b0 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -20,9 +20,10 @@ import numpy as np -from daal4py.sklearn._utils import daal_check_version, get_dtype -from onedal import _backend +from daal4py.sklearn._utils import daal_check_version +from onedal._device_offload import SyclQueueManager, supports_queue from onedal.basic_statistics import BasicStatistics +from onedal.common._backend import bind_default_backend if daal_check_version((2023, "P", 200)): from .kmeans_init import KMeansInit @@ -32,13 +33,14 @@ from sklearn.metrics.pairwise import euclidean_distances from sklearn.utils import check_random_state -from ..common._base import BaseEstimator as onedal_BaseEstimator +from onedal import _default_backend + from ..common._mixin import ClusterMixin, TransformerMixin from ..datatypes import from_table, to_table -from ..utils import _check_array, _is_arraylike_not_scalar, _is_csr +from ..utils.validation import _check_array, _is_arraylike_not_scalar, _is_csr -class _BaseKMeans(onedal_BaseEstimator, TransformerMixin, ClusterMixin, ABC): +class _BaseKMeans(TransformerMixin, ClusterMixin, ABC): def __init__( self, n_clusters, @@ -60,6 +62,15 @@ def __init__( self.random_state = random_state self.n_local_trials = n_local_trials + @bind_default_backend("kmeans_common", no_policy=True) + def _is_same_clustering(self, labels, best_labels, n_clusters): ... + + @bind_default_backend("kmeans.clustering") + def train(self, params, X_table, centroids_table): ... + + @bind_default_backend("kmeans.clustering") + def infer(self, params, model, centroids_table): ... 
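+
+    # restore the backend model constructor, previously obtained via
+    # `self._get_backend("kmeans", "clustering", "model")`; it is needed by the
+    # `cluster_centers_` setter below
+    @bind_default_backend("kmeans.clustering")
+    def model(self): ...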
+ def _validate_center_shape(self, X, centers): """Check if centers is compatible with X and n_clusters.""" if centers.shape[0] != self.n_clusters: @@ -80,7 +91,7 @@ def _get_kmeans_init(self, cluster_count, seed, algorithm): def _get_basic_statistics_backend(self, result_options): return BasicStatistics(result_options) - def _tolerance(self, X_table, rtol, is_csr, policy, dtype): + def _tolerance(self, X_table, rtol, is_csr, dtype): """Compute absolute tolerance from the relative tolerance""" if rtol == 0.0: return rtol @@ -88,13 +99,13 @@ def _tolerance(self, X_table, rtol, is_csr, policy, dtype): bs = self._get_basic_statistics_backend("variance") - res = bs._compute_raw(X_table, dummy, policy, dtype, is_csr) + res = bs._compute_raw(X_table, dummy, dtype, is_csr) mean_var = from_table(res["variance"]).mean() return mean_var * rtol def _check_params_vs_input( - self, X_table, is_csr, policy, default_n_init=10, dtype=np.float32 + self, X_table, is_csr, default_n_init=10, dtype=np.float32 ): # n_clusters if X_table.shape[0] < self.n_clusters: @@ -103,7 +114,7 @@ def _check_params_vs_input( ) # tol - self._tol = self._tolerance(X_table, self.tol, is_csr, policy, dtype) + self._tol = self._tolerance(X_table, self.tol, is_csr, dtype) # n-init # TODO(1.4): Remove @@ -159,42 +170,24 @@ def _init_centroids_onedal( X_table, init, random_seed, - policy, is_csr, dtype=np.float32, n_centroids=None, ): n_clusters = self.n_clusters if n_centroids is None else n_centroids - # Use host policy for KMeans init, only for csr data - # as oneDAL KMeansInit for CSR data is not implemented on GPU - if is_csr: - init_policy = self._get_policy(None, None) - logging.getLogger("sklearnex").info("Running Sparse KMeansInit on CPU") - else: - init_policy = policy if isinstance(init, str) and init == "k-means++": - if not is_csr: - alg = self._get_kmeans_init( - cluster_count=n_clusters, - seed=random_seed, - algorithm="plus_plus_dense", - ) - else: - alg = self._get_kmeans_init( - cluster_count=n_clusters, seed=random_seed, algorithm="plus_plus_csr" - ) - centers_table = alg.compute_raw(X_table, init_policy, dtype) + algorithm = "plus_plus_dense" if not is_csr else "plus_plus_csr" + alg = self._get_kmeans_init( + cluster_count=n_clusters, seed=random_seed, algorithm=algorithm + ) + centers_table = alg.compute_raw(X_table, dtype) elif isinstance(init, str) and init == "random": - if not is_csr: - alg = self._get_kmeans_init( - cluster_count=n_clusters, seed=random_seed, algorithm="random_dense" - ) - else: - alg = self._get_kmeans_init( - cluster_count=n_clusters, seed=random_seed, algorithm="random_csr" - ) - centers_table = alg.compute_raw(X_table, init_policy, dtype) + algorithm = "random_dense" if not is_csr else "random_csr" + alg = self._get_kmeans_init( + cluster_count=n_clusters, seed=random_seed, algorithm=algorithm + ) + centers_table = alg.compute_raw(X_table, dtype) elif _is_arraylike_not_scalar(init): if _is_csr(init): # oneDAL KMeans only supports Dense Centroids @@ -205,13 +198,13 @@ def _init_centroids_onedal( assert centers.shape[1] == X_table.column_count # KMeans is implemented on both CPU and GPU for Dense and CSR data # The original policy can be used here - centers_table = to_table(centers, queue=getattr(policy, "_queue", None)) + centers_table = to_table(centers, queue=SyclQueueManager.get_global_queue()) else: raise TypeError("Unsupported type of the `init` value") return centers_table - def _init_centroids_sklearn(self, X, init, random_state, policy, dtype=np.float32): + def 
_init_centroids_sklearn(self, X, init, random_state, dtype=np.float32):
         # For oneDAL versions < 2023.2 or callable init,
         # using the scikit-learn implementation
         logging.getLogger("sklearnex").info("Computing KMeansInit with Stock sklearn")
@@ -239,17 +232,17 @@ def _init_centroids_sklearn(self, X, init, random_state, policy, dtype=np.float3
                 f"callable, got '{ init }' instead."
             )

-        return to_table(centers, queue=getattr(policy, "_queue", None))
+        return to_table(centers, queue=SyclQueueManager.get_global_queue())

-    def _fit_backend(
-        self, X_table, centroids_table, module, policy, dtype=np.float32, is_csr=False
-    ):
+    def _fit_backend(self, X_table, centroids_table, dtype=np.float32, is_csr=False):
         params = self._get_onedal_params(is_csr, dtype)

-        meta = _backend.get_table_metadata(X_table)
+        meta = _default_backend.get_table_metadata(X_table)
         assert meta.get_npy_dtype(0) == dtype

-        result = module.train(policy, params, X_table, centroids_table)
+        result = self.train(params, X_table, centroids_table)

         return (
             result.responses,
@@ -258,16 +251,15 @@ def _fit_backend(
             result.iteration_count,
         )

-    def _fit(self, X, module, queue=None):
-        policy = self._get_policy(queue, X)
+    def _fit(self, X):
         is_csr = _is_csr(X)
         X = _check_array(
             X, dtype=[np.float64, np.float32], accept_sparse="csr", force_all_finite=False
         )
-        X_table = to_table(X, queue=queue)
+        X_table = to_table(X, queue=SyclQueueManager.get_global_queue())
         dtype = X_table.dtype

-        self._check_params_vs_input(X_table, is_csr, policy, dtype=dtype)
+        self._check_params_vs_input(X_table, is_csr, dtype=dtype)

         self.n_features_in_ = X_table.column_count

@@ -278,12 +270,10 @@ def is_better_iteration(inertia, labels):
             if best_inertia is None:
                 return True
             else:
-                mod = self._get_backend("kmeans_common", None, None)
                 better_inertia = inertia < best_inertia
-                same_clusters = mod._is_same_clustering(
+                return better_inertia and not self._is_same_clustering(
                     labels, best_labels, self.n_clusters
                 )
-                return better_inertia and not same_clusters

         random_state = check_random_state(self.random_state)

@@ -301,18 +291,18 @@ def is_better_iteration(inertia, labels):
             if use_onedal_init:
                 random_seed = random_state.randint(np.iinfo("i").max)
                 centroids_table = self._init_centroids_onedal(
-                    X_table, init, random_seed, policy, is_csr, dtype=dtype
+                    X_table, init, random_seed, is_csr, dtype=dtype
                 )
             else:
                 centroids_table = self._init_centroids_sklearn(
-                    X, init, random_state, policy, dtype=dtype
+                    X, init, random_state, dtype=dtype
                 )

             if self.verbose:
                 print("Initialization complete")

             labels, inertia, model, n_iter = self._fit_backend(
-                X_table, centroids_table, module, policy, dtype, is_csr
+                X_table, centroids_table, dtype, is_csr
             )

             if self.verbose:
@@ -351,7 +341,7 @@ def cluster_centers_(self):
             centroids = self.model_.centroids
             self._cluster_centers_ = from_table(centroids)
         else:
-            raise NameError("This model have not been trained")
+            raise NameError("This model has not been trained")
         return self._cluster_centers_

     @cluster_centers_.setter
@@ -361,7 +351,6 @@ def cluster_centers_(self, cluster_centers):

         self.n_iter_ = 0
         self.inertia_ = 0

-        self.model_ = self._get_backend("kmeans", "clustering", "model")
+        self.model_ = self.model()
         self.model_.centroids = to_table(self._cluster_centers_)
         self.n_features_in_ = self.model_.centroids.column_count
         self.labels_ = np.arange(self.model_.centroids.row_count)

@@ -372,27 +361,26 @@ def cluster_centers_(self):
     def cluster_centers_(self):
         del self._cluster_centers_

-    def _predict(self, X, module, queue=None, 
result_options=None): + def _predict(self, X, result_options=None): is_csr = _is_csr(X) - policy = self._get_policy(queue, X) - X_table = to_table(X, queue=queue) + X_table = to_table(X, queue=SyclQueueManager.get_global_queue()) params = self._get_onedal_params(is_csr, X_table.dtype, result_options) - result = module.infer(policy, params, self.model_, X_table) + result = self.infer(params, self.model_, X_table) - if ( - result_options == "compute_exact_objective_function" - ): # This is only set for score function - return result.objective_function_value * (-1) + if result_options == "compute_exact_objective_function": + # This is only set for score function + return -1 * result.objective_function_value else: return from_table(result.responses).ravel() - def _score(self, X, module, queue=None): + def _score(self, X): result_options = "compute_exact_objective_function" return self._predict( - X, self._get_backend("kmeans", "clustering", None), queue, result_options + X, + result_options, ) def _transform(self, X): @@ -427,9 +415,11 @@ def __init__( self.algorithm = algorithm assert self.algorithm == "lloyd" + @supports_queue def fit(self, X, y=None, queue=None): - return super()._fit(X, self._get_backend("kmeans", "clustering", None), queue) + return self._fit(X) + @supports_queue def predict(self, X, queue=None): """Predict the closest cluster each sample in X belongs to. @@ -447,7 +437,7 @@ def predict(self, X, queue=None): labels : ndarray of shape (n_samples,) Index of the cluster each sample belongs to. """ - return super()._predict(X, self._get_backend("kmeans", "clustering", None), queue) + return self._predict(X) def fit_predict(self, X, y=None, queue=None): """Compute cluster centers and predict cluster index for each sample. @@ -510,6 +500,7 @@ def transform(self, X): return self._transform(X) + @supports_queue def score(self, X, queue=None): """Opposite of the value of X on the K-means objective. @@ -523,7 +514,7 @@ def score(self, X, queue=None): score: float Opposite of the value of X on the K-means objective. """ - return super()._score(X, self._get_backend("kmeans", "clustering", None), queue) + return self._score(X) def k_means( diff --git a/onedal/cluster/kmeans_init.py b/onedal/cluster/kmeans_init.py index 4543fba003..58797ea70a 100755 --- a/onedal/cluster/kmeans_init.py +++ b/onedal/cluster/kmeans_init.py @@ -15,18 +15,18 @@ # ============================================================================== import numpy as np -from scipy.sparse import issparse from sklearn.utils import check_random_state -from daal4py.sklearn._utils import daal_check_version, get_dtype +from daal4py.sklearn._utils import daal_check_version +from onedal._device_offload import SyclQueueManager, supports_queue +from onedal.common._backend import bind_default_backend -from ..common._base import BaseEstimator as onedal_BaseEstimator from ..datatypes import from_table, to_table -from ..utils import _check_array +from ..utils.validation import _check_array if daal_check_version((2023, "P", 200)): - class KMeansInit(onedal_BaseEstimator): + class KMeansInit: """ KMeansInit oneDAL implementation. """ @@ -48,6 +48,9 @@ def __init__( else: self.local_trials_count = local_trials_count + @bind_default_backend("kmeans_init.init", lookup_name="compute") + def backend_compute(self, params, X_table): ... 
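+
+        # `lookup_name="compute"` binds this stub to `<backend>.kmeans_init.init.compute`
+        # under a different Python name, so it does not shadow the public
+        # `compute(X, queue=...)` method defined below.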
+
         def _get_onedal_params(self, dtype=np.float32):
             return {
                 "fptype": dtype,
@@ -68,31 +71,24 @@ def _get_params_and_input(self, X, queue):
             params = self._get_onedal_params(X.dtype)
             return (params, X, X.dtype)

-        def _compute_raw(self, X_table, module, policy, dtype=np.float32):
+        def _compute_raw(self, X_table, dtype=np.float32):
             params = self._get_onedal_params(dtype)
-
-            result = module.compute(policy, params, X_table)
-
+            result = self.backend_compute(params, X_table)
             return result.centroids

-        def _compute(self, X, module, queue):
-            policy = self._get_policy(queue, X)
-            # oneDAL KMeans Init for sparse data does not have GPU support
-            if issparse(X):
-                policy = self._get_policy(None, None)
-            _, X_table, dtype = self._get_params_and_input(X, queue)
-
-            centroids = self._compute_raw(X_table, module, policy, dtype)
-
+        def _compute(self, X):
+            _, X_table, dtype = self._get_params_and_input(
+                X, queue=SyclQueueManager.get_global_queue()
+            )
+            centroids = self._compute_raw(X_table, dtype)
             return from_table(centroids)

-        def compute_raw(self, X_table, policy, dtype=np.float32):
-            return self._compute_raw(
-                X_table, self._get_backend("kmeans_init", "init", None), policy, dtype
-            )
+        def compute_raw(self, X_table, dtype=np.float32, queue=None):
+            return self._compute_raw(X_table, dtype)

+        @supports_queue
         def compute(self, X, queue=None):
-            return self._compute(X, self._get_backend("kmeans_init", "init", None), queue)
+            return self._compute(X)

     def kmeans_plusplus(
         X,
@@ -107,6 +103,6 @@
         return (
             KMeansInit(
                 n_clusters, seed=random_seed, local_trials_count=n_local_trials
-            ).compute(X, queue),
+            ).compute(X, queue=queue),
             np.full(n_clusters, -1),
         )
diff --git a/onedal/common/_backend.py b/onedal/common/_backend.py
new file mode 100644
index 0000000000..61e72c4eee
--- /dev/null
+++ b/onedal/common/_backend.py
@@ -0,0 +1,238 @@
+# ==============================================================================
+# Copyright 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import logging
+from typing import Any, Callable, Literal, Optional
+
+from onedal import Backend, _default_backend, _spmd_backend
+from onedal._device_offload import SyclQueueManager
+
+logger = logging.getLogger(__name__)
+
+# define types for backend functions: default, dpc, spmd
+BackendType = Literal["none", "host", "dpc", "spmd"]
+
+
+class BackendManager:
+    def __init__(self, backend_module):
+        self.backend = backend_module
+
+    def get_backend_type(self) -> BackendType:
+        if self.backend is None:
+            return "none"
+        if self.backend.is_spmd:
+            return "spmd"
+        if self.backend.is_dpc:
+            return "dpc"
+        return "host"
+
+    def get_backend_component(self, module_name: str, component_name: str):
+        """Get a component of the backend module.
+
+        Args:
+            module(str): The module to get the component from.
+ component: The component to get from the module. + + Returns: + The component of the module. + """ + submodules = module_name.split(".") + module = getattr(self.backend, submodules[0]) + for submodule in submodules[1:]: + module = getattr(module, submodule) + + # component can be provided like submodule.method, there can be arbitrary number of submodules + # and methods + result = module + for part in component_name.split("."): + result = getattr(result, part) + + return result + + +default_manager = BackendManager(_default_backend) +spmd_manager = BackendManager(_spmd_backend) + + +class BackendFunction: + """Wrapper around backend function to allow setting auxiliary information""" + + def __init__( + self, + method: Callable[..., Any], + backend: Backend, + name: str, + no_policy: bool, + ): + self.method = method + self.name = name + self.backend = backend + self.no_policy = no_policy + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Dispatch to backend function with the appropriate policy which is determined from the global queue""" + if not args and not kwargs: + # immediate dispatching without any arguments, in particular no policy + return self.method() + + if self.no_policy: + return self.method(*args, **kwargs) + + # use globally configured queue (from `target_offload` configuration or provided data) + queue = SyclQueueManager.get_global_queue() + + if queue is not None and not (self.backend.is_dpc or self.backend.is_spmd): + raise RuntimeError("Operations using queues require the DPC/SPMD backend") + + if self.backend.is_spmd and queue is None: + raise RuntimeError("Executing functions from SPMD backend requires a queue") + + # craft the correct policy including the device queue + if queue is None: + policy = self.backend.host_policy() + elif self.backend.is_spmd: + policy = self.backend.spmd_data_parallel_policy(queue) + elif self.backend.is_dpc: + policy = self.backend.data_parallel_policy(queue) + else: + policy = self.backend.host_policy() + + logger.debug( + f"Dispatching function '{self.name}' with policy {policy} to {self.backend}" + ) + + # dispatch to backend function + return self.method(policy, *args, **kwargs) + + def __repr__(self) -> str: + return f"BackendFunction({self.backend}.{self.name})" + + +def __decorator( + method: Callable[..., Any], + backend_manager: BackendManager, + module_name: str, + lookup_name: Optional[str], + no_policy: bool, +) -> Callable[..., Any]: + """Decorator to bind a method to the specified backend""" + if lookup_name is None: + lookup_name = method.__name__ + + if backend_manager.get_backend_type() == "none": + raise RuntimeError("Internal __decorator() should not be called with no backend") + + backend_method = backend_manager.get_backend_component(module_name, lookup_name) + wrapped_method = BackendFunction( + backend_method, + backend_manager.backend, + name=f"{module_name}.{method.__name__}", + no_policy=no_policy, + ) + + backend_type = backend_manager.get_backend_type() + logger.debug( + f"Assigned method '<{backend_type}_backend>.{module_name}.{lookup_name}' to '{method.__qualname__}'" + ) + + return wrapped_method + + +def bind_default_backend( + module_name: str, lookup_name: Optional[str] = None, no_policy=False +): + """ + Decorator to bind a method from the default backend to a class. + + This decorator binds a method implementation from the default backend (host/dpc). + If the default backend is unavailable, the method is returned without modification. 
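+
+    For example, `EmpiricalCovariance` in this patch binds a `compute` stub with
+    `@bind_default_backend("covariance")`; a call like `self.compute(params, X)` then
+    dispatches to the backend's `covariance.compute` with a policy derived from the
+    current global queue.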
+ + Parameters: + ---------- + module_name : str + The name of the module where the target function is located (e.g. `covariance`). + lookup_name : Optional[str], optional + The name of the method to look up in the backend module. If not provided, + the name of the decorated method is used. + no_policy : bool, optional + If True, the method will be decorated without a policy. Default is False. + + Returns: + ------- + Callable[..., Any] + The decorated method bound to the implementation in default backend, or the original + method if the default backend is unavailable. + """ + + def decorator(method: Callable[..., Any]): + # grab the lookup_name from outer scope + nonlocal lookup_name + + if _default_backend is None: + logger.debug( + f"Default backend unavailable, skipping decoration for '{method.__name__}'" + ) + return method + + return __decorator(method, default_manager, module_name, lookup_name, no_policy) + + return decorator + + +def bind_spmd_backend( + module_name: str, lookup_name: Optional[str] = None, no_policy=False +): + """ + Decorator to bind a method from the SPMD backend to a class. + + This decorator binds a method implementation from the SPMD backend. + If the SPMD backend is unavailable, the method is returned without modification. + + Parameters: + ---------- + module_name : str + The name of the module where the target function is located (e.g. `covariance`). + lookup_name : Optional[str], optional + The name of the method to look up in the backend module. If not provided, + the name of the decorated method is used. + no_policy : bool, optional + If True, the method will be decorated without a policy. Default is False. + + Returns: + ------- + Callable[..., Any] + The decorated method bound to the implementation in SPMD backend, or the original + method if the SPMD backend is unavailable. + """ + + def decorator(method: Callable[..., Any]): + # grab the lookup_name from outer scope + nonlocal lookup_name + + if _spmd_backend is None: + logger.debug( + f"SPMD backend unavailable, skipping decoration for '{method.__name__}'" + ) + return method + + return __decorator(method, spmd_manager, module_name, lookup_name, no_policy) + + return decorator diff --git a/onedal/common/_base.py b/onedal/common/_base.py deleted file mode 100644 index 3129b8d3cb..0000000000 --- a/onedal/common/_base.py +++ /dev/null @@ -1,38 +0,0 @@ -# ============================================================================== -# Copyright 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -from abc import ABC - -from onedal import _backend - -from ._policy import _get_policy - - -def _get_backend(backend, module, submodule=None, method=None, *args, **kwargs): - result = getattr(backend, module) - if submodule: - result = getattr(result, submodule) - if method: - return getattr(result, method)(*args, **kwargs) - return result - - -class BaseEstimator(ABC): - def _get_backend(self, module, submodule=None, method=None, *args, **kwargs): - return _get_backend(_backend, module, submodule, method, *args, **kwargs) - - def _get_policy(self, queue, *data): - return _get_policy(queue, *data) diff --git a/onedal/common/_policy.py b/onedal/common/_policy.py deleted file mode 100644 index 0d7d8ca6a3..0000000000 --- a/onedal/common/_policy.py +++ /dev/null @@ -1,55 +0,0 @@ -# ============================================================================== -# Copyright 2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import sys - -from onedal import _backend, _is_dpc_backend - - -def _get_policy(queue, *data): - data_queue = _get_queue(*data) - if _is_dpc_backend: - if queue is None: - if data_queue is None: - return _HostInteropPolicy() - return _DataParallelInteropPolicy(data_queue) - return _DataParallelInteropPolicy(queue) - else: - if not (data_queue is None and queue is None): - raise RuntimeError( - "Operation using the requested SYCL queue requires the DPC backend" - ) - return _HostInteropPolicy() - - -def _get_queue(*data): - if len(data) > 0 and hasattr(data[0], "__sycl_usm_array_interface__"): - # Assume that all data reside on the same device - return data[0].__sycl_usm_array_interface__["syclobj"] - return None - - -class _HostInteropPolicy(_backend.host_policy): - def __init__(self): - super().__init__() - - -if _is_dpc_backend: - - class _DataParallelInteropPolicy(_backend.data_parallel_policy): - def __init__(self, queue): - self._queue = queue - super().__init__(self._queue) diff --git a/onedal/common/_spmd_policy.py b/onedal/common/_spmd_policy.py deleted file mode 100644 index a9f83c8a47..0000000000 --- a/onedal/common/_spmd_policy.py +++ /dev/null @@ -1,30 +0,0 @@ -# ============================================================================== -# Copyright 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -from onedal import _is_spmd_backend - -if _is_spmd_backend: - from onedal import _spmd_backend - - class _SPMDDataParallelInteropPolicy(_spmd_backend.spmd_data_parallel_policy): - def __init__(self, queue): - self._queue = queue - super().__init__(self._queue) - - def _get_spmd_policy(queue): - # TODO: - # cases when queue is None - return _SPMDDataParallelInteropPolicy(queue) diff --git a/onedal/common/hyperparameters.py b/onedal/common/hyperparameters.py index c32440cf62..9471999a57 100644 --- a/onedal/common/hyperparameters.py +++ b/onedal/common/hyperparameters.py @@ -19,7 +19,7 @@ from warnings import warn from daal4py.sklearn._utils import daal_check_version -from onedal import _backend +from onedal import _default_backend as backend if not daal_check_version((2024, "P", 0)): warn("Hyperparameters are supported in oneDAL starting from 2024.0.0 version.") @@ -98,11 +98,11 @@ def get_methods_with_prefix(obj, prefix): ( "linear_regression", "train", - ): _backend.linear_model.regression.train_hyperparameters(), - ("covariance", "compute"): _backend.covariance.compute_hyperparameters(), + ): backend.linear_model.regression.train_hyperparameters(), + ("covariance", "compute"): backend.covariance.compute_hyperparameters(), } if daal_check_version((2024, "P", 300)): - df_infer_hp = _backend.decision_forest.infer_hyperparameters + df_infer_hp = backend.decision_forest.infer_hyperparameters hyperparameters_backend[("decision_forest", "infer")] = df_infer_hp() hyperparameters_map = {} diff --git a/onedal/common/tests/test_policy.py b/onedal/common/tests/test_policy.py deleted file mode 100644 index 36d9865e23..0000000000 --- a/onedal/common/tests/test_policy.py +++ /dev/null @@ -1,76 +0,0 @@ -# ============================================================================== -# Copyright 2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -import numpy as np -import pytest - -from onedal.common._policy import _get_policy -from onedal.tests.utils._device_selection import ( - device_type_to_str, - get_memory_usm, - get_queues, - is_dpctl_device_available, -) -from onedal.utils._dpep_helpers import dpctl_available - - -@pytest.mark.parametrize("queue", get_queues()) -def test_queue_passed_directly(queue): - device_name = device_type_to_str(queue) - test_queue = _get_policy(queue) - test_device_name = test_queue.get_device_name() - assert test_device_name == device_name - - -@pytest.mark.parametrize("queue", get_queues()) -def test_with_numpy_data(queue): - X = np.zeros((5, 3)) - y = np.zeros(3) - - device_name = device_type_to_str(queue) - assert _get_policy(queue, X, y).get_device_name() == device_name - - -@pytest.mark.skipif(not dpctl_available, reason="depends on dpctl") -@pytest.mark.parametrize("queue", get_queues("cpu,gpu")) -@pytest.mark.parametrize("memtype", get_memory_usm()) -def test_with_usm_ndarray_data(queue, memtype): - if queue is None: - pytest.skip( - "dpctl Memory object with queue=None uses cached default (gpu if available)" - ) - - from dpctl.tensor import usm_ndarray - - device_name = device_type_to_str(queue) - X = usm_ndarray((5, 3), buffer=memtype(5 * 3 * 8, queue=queue)) - y = usm_ndarray((3,), buffer=memtype(3 * 8, queue=queue)) - assert _get_policy(None, X, y).get_device_name() == device_name - - -@pytest.mark.skipif( - not is_dpctl_device_available(["cpu", "gpu"]), reason="test uses multiple devices" -) -@pytest.mark.parametrize("memtype", get_memory_usm()) -def test_queue_parameter_with_usm_ndarray(memtype): - from dpctl import SyclQueue - from dpctl.tensor import usm_ndarray - - q1 = SyclQueue("cpu") - q2 = SyclQueue("gpu") - - X = usm_ndarray((5, 3), buffer=memtype(5 * 3 * 8, queue=q1)) - assert _get_policy(q2, X).get_device_name() == device_type_to_str(q2) diff --git a/onedal/common/tests/test_sycl.py b/onedal/common/tests/test_sycl.py index dc892663b6..b98060b21c 100644 --- a/onedal/common/tests/test_sycl.py +++ b/onedal/common/tests/test_sycl.py @@ -14,16 +14,15 @@ # limitations under the License. 
# ============================================================================== -import numpy as np import pytest -from onedal import _backend, _is_dpc_backend +from onedal import _default_backend as backend from onedal.tests.utils._device_selection import get_queues from onedal.utils._dpep_helpers import dpctl_available @pytest.mark.skipif( - not _is_dpc_backend or not dpctl_available, reason="requires dpc backend and dpctl" + not backend.is_dpc or not dpctl_available, reason="requires dpc backend and dpctl" ) @pytest.mark.parametrize("device_type", ["cpu", "gpu"]) @pytest.mark.parametrize("device_number", [None, 0, 1, 2, 3]) @@ -32,7 +31,7 @@ def test_sycl_queue_string_creation(device_type, device_number): from dpctl import SyclQueue from dpctl._sycl_queue import SyclQueueCreationError - onedal_SyclQueue = _backend.SyclQueue + onedal_SyclQueue = backend.SyclQueue device = ( ":".join([device_type, str(device_number)]) @@ -63,14 +62,14 @@ def test_sycl_queue_string_creation(device_type, device_number): @pytest.mark.skipif( - not _is_dpc_backend or not dpctl_available, reason="requires dpc backend and dpctl" + not backend.is_dpc or not dpctl_available, reason="requires dpc backend and dpctl" ) @pytest.mark.parametrize("queue", get_queues()) def test_sycl_queue_conversion(queue): if queue is None: pytest.skip("Not a dpctl queue") SyclQueue = queue.__class__ - onedal_SyclQueue = _backend.SyclQueue + onedal_SyclQueue = backend.SyclQueue q = onedal_SyclQueue(queue) @@ -83,7 +82,7 @@ def test_sycl_queue_conversion(queue): @pytest.mark.skipif( - not _is_dpc_backend or not dpctl_available, reason="requires dpc backend and dpctl" + not backend.is_dpc or not dpctl_available, reason="requires dpc backend and dpctl" ) @pytest.mark.parametrize("queue", get_queues()) def test_sycl_device_attributes(queue): @@ -91,7 +90,7 @@ def test_sycl_device_attributes(queue): if queue is None: pytest.skip("Not a dpctl queue") - onedal_SyclQueue = _backend.SyclQueue + onedal_SyclQueue = backend.SyclQueue onedal_queue = onedal_SyclQueue(queue) @@ -107,17 +106,17 @@ def test_sycl_device_attributes(queue): assert onedal_queue.sycl_device.filter_string in queue.sycl_device.filter_string -@pytest.mark.skipif(not _is_dpc_backend, reason="requires dpc backend") +@pytest.mark.skipif(not backend.is_dpc, reason="requires dpc backend") def test_backend_queue(): try: - q = _backend.SyclQueue("cpu") + q = backend.SyclQueue("cpu") except RuntimeError: pytest.skip("OpenCL CPU runtime not installed") # verify copying via a py capsule object is functional - q2 = _backend.SyclQueue(q._get_capsule()) + q2 = backend.SyclQueue(q._get_capsule()) # verify copying via the _get_capsule attribute - q3 = _backend.SyclQueue(q) + q3 = backend.SyclQueue(q) q_array = [q, q2, q3] diff --git a/onedal/covariance/covariance.py b/onedal/covariance/covariance.py index 795df08dd9..bc9445c3ef 100644 --- a/onedal/covariance/covariance.py +++ b/onedal/covariance/covariance.py @@ -18,19 +18,23 @@ import numpy as np from daal4py.sklearn._utils import daal_check_version, get_dtype -from onedal.utils import _check_array +from onedal._device_offload import supports_queue +from onedal.common._backend import bind_default_backend +from onedal.utils.validation import _check_array -from ..common._base import BaseEstimator from ..common.hyperparameters import get_hyperparameters from ..datatypes import from_table, to_table -class BaseEmpiricalCovariance(BaseEstimator, metaclass=ABCMeta): +class BaseEmpiricalCovariance(metaclass=ABCMeta): def __init__(self, 
method="dense", bias=False, assume_centered=False): self.method = method self.bias = bias self.assume_centered = assume_centered + @bind_default_backend("covariance") + def compute(self, *args, **kwargs): ... + def _get_onedal_params(self, dtype=np.float32): params = { "fptype": dtype, @@ -73,6 +77,7 @@ class EmpiricalCovariance(BaseEmpiricalCovariance): Estimated covariance matrix """ + @supports_queue def fit(self, X, y=None, queue=None): """Fit the sample covariance matrix of X. @@ -93,23 +98,14 @@ def fit(self, X, y=None, queue=None): self : object Returns the instance itself. """ - policy = self._get_policy(queue, X) X = _check_array(X, dtype=[np.float64, np.float32]) X = to_table(X, queue=queue) params = self._get_onedal_params(X.dtype) hparams = get_hyperparameters("covariance", "compute") if hparams is not None and not hparams.is_default: - result = self._get_backend( - "covariance", - None, - "compute", - policy, - params, - hparams.backend, - X, - ) + result = self.compute(params, hparams.backend, X) else: - result = self._get_backend("covariance", None, "compute", policy, params, X) + result = self.compute(params, X) if daal_check_version((2024, "P", 1)) or (not self.bias): self.covariance_ = from_table(result.cov_matrix) else: diff --git a/onedal/covariance/incremental_covariance.py b/onedal/covariance/incremental_covariance.py index b0bfb04e22..2fd189e8d6 100644 --- a/onedal/covariance/incremental_covariance.py +++ b/onedal/covariance/incremental_covariance.py @@ -13,12 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== + import numpy as np -from daal4py.sklearn._utils import daal_check_version, get_dtype +from daal4py.sklearn._utils import daal_check_version +from onedal._device_offload import SyclQueueManager, supports_queue +from onedal.common._backend import bind_default_backend from ..datatypes import from_table, to_table -from ..utils import _check_array +from ..utils.validation import _check_array from .covariance import BaseEmpiricalCovariance @@ -55,12 +58,21 @@ class IncrementalEmpiricalCovariance(BaseEmpiricalCovariance): def __init__(self, method="dense", bias=False, assume_centered=False): super().__init__(method, bias, assume_centered) self._reset() + self._queue = None + + @bind_default_backend("covariance") + def partial_compute(self, params, partial_result, X_table): ... + + @bind_default_backend("covariance") + def partial_compute_result(self): ... + + @bind_default_backend("covariance") + def finalize_compute(self, params, partial_result): ... 
     def _reset(self):
         self._need_to_finalize = False
-        self._partial_result = self._get_backend(
-            "covariance", None, "partial_compute_result"
-        )
+        self._queue = None
+        self._partial_result = self.partial_compute_result()
 
     def __getstate__(self):
         # Since finalize_fit can't be dispatched without directly provided queue
@@ -73,6 +85,7 @@ def __getstate__(self):
 
         return data
 
+    @supports_queue
     def partial_fit(self, X, y=None, queue=None):
         """
         Computes partial data for the covariance matrix
@@ -98,27 +111,19 @@ def partial_fit(self, X, y=None, queue=None):
         X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True)
 
-        self._queue = queue
-
-        policy = self._get_policy(queue, X)
-
         X_table = to_table(X, queue=queue)
         if not hasattr(self, "_dtype"):
             self._dtype = X_table.dtype
 
         params = self._get_onedal_params(self._dtype)
-        self._partial_result = self._get_backend(
-            "covariance",
-            None,
-            "partial_compute",
-            policy,
-            params,
-            self._partial_result,
-            X_table,
-        )
+        self._partial_result = self.partial_compute(params, self._partial_result, X_table)
         self._need_to_finalize = True
+        # store the queue for when we finalize
+        self._queue = queue
 
-    def finalize_fit(self, queue=None):
+    def finalize_fit(self):
         """
         Finalizes covariance matrix and obtains `covariance_` and `location_`
         attributes from the current `_partial_result`.
@@ -135,19 +140,9 @@ def finalize_fit(self, queue=None):
         """
         if self._need_to_finalize:
             params = self._get_onedal_params(self._dtype)
-            if queue is not None:
-                policy = self._get_policy(queue)
-            else:
-                policy = self._get_policy(self._queue)
-
-            result = self._get_backend(
-                "covariance",
-                None,
-                "finalize_compute",
-                policy,
-                params,
-                self._partial_result,
-            )
+            with SyclQueueManager.manage_global_queue(self._queue):
+                result = self.finalize_compute(params, self._partial_result)
+
             if daal_check_version((2024, "P", 1)) or (not self.bias):
                 self.covariance_ = from_table(result.cov_matrix)
             else:
diff --git a/onedal/datatypes/_data_conversion.py b/onedal/datatypes/_data_conversion.py
index 5e90ae5d35..6c2e10cd9c 100644
--- a/onedal/datatypes/_data_conversion.py
+++ b/onedal/datatypes/_data_conversion.py
@@ -14,11 +14,9 @@
 # limitations under the License.
 # ==============================================================================
 
-import warnings
-
 import numpy as np
 
-from onedal import _backend, _is_dpc_backend
+from onedal import _default_backend as backend
 
 
 def _apply_and_pass(func, *args, **kwargs):
@@ -29,7 +27,7 @@ def _apply_and_pass(func, *args, **kwargs):
 
 def _convert_one_to_table(arg, queue=None):
     # All inputs for table conversion must be array-like or sparse, not scalars
-    return _backend.to_table(np.atleast_2d(arg) if np.isscalar(arg) else arg, queue)
+    return backend.to_table(np.atleast_2d(arg) if np.isscalar(arg) else arg, queue)
 
 
 def to_table(*args, queue=None):
@@ -54,7 +52,7 @@ def to_table(*args, queue=None):
     return _apply_and_pass(_convert_one_to_table, *args, queue=queue)
 
 
-if _is_dpc_backend:
+if backend.is_dpc:
 
     try:
         # try/catch is used here instead of dpep_helpers because
@@ -79,8 +77,6 @@ def _table_to_array(table, xp=None):
         def _table_to_array(table, xp=None):
             return xp.asarray(table)
 
-    from ..common._policy import _HostInteropPolicy
-
     def convert_one_from_table(table, sycl_queue=None, sua_iface=None, xp=None):
         # Currently only `__sycl_usm_array_interface__` protocol used to
         # convert into dpnp/dpctl tensors.
@@ -96,12 +92,12 @@ def convert_one_from_table(table, sycl_queue=None, sua_iface=None, xp=None): # Host tables first converted into numpy.narrays and then to array from xp # namespace. return xp.asarray( - _backend.from_table(table), usm_type="device", sycl_queue=sycl_queue + backend.from_table(table), usm_type="device", sycl_queue=sycl_queue ) else: return _table_to_array(table, xp=xp) - return _backend.from_table(table) + return backend.from_table(table) else: @@ -112,7 +108,7 @@ def convert_one_from_table(table, sycl_queue=None, sua_iface=None, xp=None): raise RuntimeError( "SYCL usm array conversion from table requires the DPC backend" ) - return _backend.from_table(table) + return backend.from_table(table) def from_table(*args, sycl_queue=None, sua_iface=None, xp=None): diff --git a/onedal/datatypes/tests/test_data.py b/onedal/datatypes/tests/test_data.py index 21ca7151fe..55426a889e 100644 --- a/onedal/datatypes/tests/test_data.py +++ b/onedal/datatypes/tests/test_data.py @@ -19,10 +19,12 @@ import scipy.sparse as sp from numpy.testing import assert_allclose -from onedal import _backend, _is_dpc_backend +from onedal import _default_backend, _dpc_backend from onedal.datatypes import from_table, to_table from onedal.utils._dpep_helpers import dpctl_available +backend = _dpc_backend or _default_backend + if dpctl_available: from onedal.datatypes.tests.common import ( _assert_sua_iface_fields, @@ -53,29 +55,25 @@ ORDER_DICT = {"F": np.asfortranarray, "C": np.ascontiguousarray} -if _is_dpc_backend: +if backend.is_dpc: from daal4py.sklearn._utils import get_dtype - from onedal.cluster.dbscan import BaseDBSCAN - from onedal.common._policy import _get_policy + from onedal.cluster.dbscan import DBSCAN class DummyEstimatorWithTableConversions: def fit(self, X, y=None): sua_iface, xp, _ = _get_sycl_namespace(X) - policy = _get_policy(X.sycl_queue, None) - bs_DBSCAN = BaseDBSCAN() + dbscan = DBSCAN() types = [xp.float32, xp.float64] if get_dtype(X) not in types: X = xp.astype(X, dtype=xp.float64) dtype = get_dtype(X) - params = bs_DBSCAN._get_onedal_params(dtype) + params = dbscan._get_onedal_params(dtype) X_table = to_table(X) # TODO: # check other candidates for the dummy base oneDAL func. # oneDAL backend func is needed to check result table checks. 
- result = _backend.dbscan.clustering.compute( - policy, params, X_table, to_table(None) - ) + result = dbscan.compute(params, X_table, to_table(None)) result_responses_table = result.responses result_responses_df = from_table( result_responses_table, @@ -231,7 +229,7 @@ def test_conversion_to_table(dtype): reason="dpctl is required for checks.", ) @pytest.mark.skipif( - not _is_dpc_backend, + not backend.is_dpc, reason="__sycl_usm_array_interface__ support requires DPC backend.", ) @pytest.mark.parametrize( @@ -267,7 +265,7 @@ def test_input_sua_iface_zero_copy(dataframe, queue, order, dtype): reason="dpctl is required for checks.", ) @pytest.mark.skipif( - not _is_dpc_backend, + not backend.is_dpc, reason="__sycl_usm_array_interface__ support requires DPC backend.", ) @pytest.mark.parametrize( @@ -324,7 +322,7 @@ def test_table_conversions(dataframe, queue, order, data_shape, dtype): @pytest.mark.skipif( - not _is_dpc_backend, + not backend.is_dpc, reason="__sycl_usm_array_interface__ support requires DPC backend.", ) @pytest.mark.parametrize( @@ -344,7 +342,7 @@ def test_sua_iface_interop_invalid_shape(dataframe, queue, data_shape): @pytest.mark.skipif( - not _is_dpc_backend, + not backend.is_dpc, reason="__sycl_usm_array_interface__ support requires DPC backend.", ) @pytest.mark.parametrize( @@ -376,7 +374,7 @@ def test_sua_iface_interop_unsupported_dtypes(dataframe, queue, dtype): "dataframe,queue", get_dataframes_and_queues("numpy,dpctl,dpnp", "cpu,gpu") ) def test_to_table_non_contiguous_input(dataframe, queue): - if dataframe in "dpnp,dpctl" and not _is_dpc_backend: + if dataframe in "dpnp,dpctl" and not backend.is_dpc: pytest.skip("__sycl_usm_array_interface__ support requires DPC backend.") X, _ = np.mgrid[:10, :10] X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe) @@ -389,7 +387,7 @@ def test_to_table_non_contiguous_input(dataframe, queue): @pytest.mark.skipif( - _is_dpc_backend, + backend.is_dpc, reason="Required check should be done if no DPC backend.", ) @pytest.mark.parametrize( @@ -407,7 +405,7 @@ def test_sua_iface_interop_if_no_dpc_backend(dataframe, queue, dtype): @pytest.mark.skipif( - not _is_dpc_backend, reason="Requires DPC backend for dtype conversion" + not backend.is_dpc, reason="Requires DPC backend for dtype conversion" ) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize("sparse", [True, False]) diff --git a/onedal/datatypes/utils/sua_iface_helpers.cpp b/onedal/datatypes/utils/sua_iface_helpers.cpp index d345a35645..931f0f4362 100644 --- a/onedal/datatypes/utils/sua_iface_helpers.cpp +++ b/onedal/datatypes/utils/sua_iface_helpers.cpp @@ -167,7 +167,7 @@ dal::data_layout get_sua_iface_layout(const py::dict& sua_dict, } } else { - throw std::runtime_error("Unsupporterd data shape.`"); + throw std::runtime_error("Unsupported data shape.`"); } } diff --git a/onedal/decomposition/incremental_pca.py b/onedal/decomposition/incremental_pca.py index 58c852ed81..0fa5be5f95 100644 --- a/onedal/decomposition/incremental_pca.py +++ b/onedal/decomposition/incremental_pca.py @@ -16,10 +16,11 @@ import numpy as np -from daal4py.sklearn._utils import get_dtype +from onedal._device_offload import SyclQueueManager, supports_queue +from onedal.common._backend import bind_default_backend from ..datatypes import from_table, to_table -from ..utils import _check_array +from ..utils.validation import _check_array from .pca import BasePCA @@ -96,14 +97,24 @@ def __init__( self.method = method self.is_deterministic = is_deterministic 
         self.whiten = whiten
+        self._queue = None
         self._reset()
 
+    @bind_default_backend("decomposition.dim_reduction")
+    def finalize_train(self, params, partial_result): ...
+
+    @bind_default_backend("decomposition.dim_reduction")
+    def partial_train(self, params, partial_result, X_table): ...
+
+    @bind_default_backend("decomposition.dim_reduction")
+    def partial_train_result(self): ...
+
     def _reset(self):
         self._need_to_finalize = False
-        module = self._get_backend("decomposition", "dim_reduction")
+        self._queue = None
+        self._partial_result = self.partial_train_result()
         if hasattr(self, "components_"):
             del self.components_
-        self._partial_result = module.partial_train_result()
 
     def __getstate__(self):
         # Since finalize_fit can't be dispatched without directly provided queue
@@ -116,7 +127,8 @@ def __getstate__(self):
 
         return data
 
-    def partial_fit(self, X, queue):
+    @supports_queue
+    def partial_fit(self, X, queue=None):
         """Incremental fit with X. All of X is processed as a single batch.
 
         Parameters
@@ -153,27 +165,21 @@ def partial_fit(self, X, queue):
             self.n_components_ = self.n_components
 
-        self._queue = queue
-
-        policy = self._get_policy(queue, X)
         X_table = to_table(X, queue=queue)
 
         if not hasattr(self, "_dtype"):
             self._dtype = X_table.dtype
             self._params = self._get_onedal_params(X_table)
 
-        self._partial_result = self._get_backend(
-            "decomposition",
-            "dim_reduction",
-            "partial_train",
-            policy,
-            self._params,
-            self._partial_result,
-            X_table,
+        self._partial_result = self.partial_train(
+            self._params, self._partial_result, X_table
         )
         self._need_to_finalize = True
+        self._queue = queue
 
         return self
 
-    def finalize_fit(self, queue=None):
+    def finalize_fit(self):
         """
         Finalizes principal components computation and obtains resulting
         attributes from the current `_partial_result`.
@@ -189,12 +195,8 @@
             Returns the instance itself.
         """
         if self._need_to_finalize:
-            module = self._get_backend("decomposition", "dim_reduction")
-            if queue is not None:
-                policy = self._get_policy(queue)
-            else:
-                policy = self._get_policy(self._queue)
-            result = module.finalize_train(policy, self._params, self._partial_result)
+            with SyclQueueManager.manage_global_queue(self._queue):
+                result = self.finalize_train(self._params, self._partial_result)
 
             self.mean_ = from_table(result.means).ravel()
             self.var_ = from_table(result.variances).ravel()
             self.components_ = from_table(result.eigenvectors)
@@ -210,5 +212,7 @@
             self.noise_variance_ = self._compute_noise_variance(
                 self.n_components_, min(self.n_samples_seen_, self.n_features_in_)
             )
-            self._need_to_finalize = False
+        self._need_to_finalize = False
+        self._queue = None
+
         return self
diff --git a/onedal/decomposition/pca.py b/onedal/decomposition/pca.py
index fe6f585ba5..c3769d90e3 100644
--- a/onedal/decomposition/pca.py
+++ b/onedal/decomposition/pca.py
@@ -21,11 +21,13 @@
 from sklearn.decomposition._pca import _infer_dimension
 from sklearn.utils.extmath import stable_cumsum
 
-from ..common._base import BaseEstimator
+from onedal._device_offload import supports_queue
+from onedal.common._backend import bind_default_backend
+
 from ..datatypes import from_table, to_table
 
 
-class BasePCA(BaseEstimator, metaclass=ABCMeta):
+class BasePCA(metaclass=ABCMeta):
     """
     Base class for PCA oneDAL implementation.
""" @@ -42,6 +44,16 @@ def __init__( self.is_deterministic = is_deterministic self.whiten = whiten + # provides direct access to the backend model constructor + @bind_default_backend("decomposition.dim_reduction") + def model(self): ... + + @bind_default_backend("decomposition.dim_reduction") + def train(self, params, X): ... + + @bind_default_backend("decomposition.dim_reduction") + def infer(self, params, X, model): ... + def _get_onedal_params(self, data, stage=None): if stage is None: n_components = self._resolve_n_components_for_training(data.shape) @@ -119,7 +131,7 @@ def _compute_noise_variance(self, n_components, n_sf_min): return 0.0 def _create_model(self): - m = self._get_backend("decomposition", "dim_reduction", "model") + m = self.model() m.eigenvectors = to_table(self.components_) m.means = to_table(self.mean_) if self.whiten: @@ -127,26 +139,23 @@ def _create_model(self): self._onedal_model = m return m + @supports_queue def predict(self, X, queue=None): - policy = self._get_policy(queue, X) model = self._create_model() X_table = to_table(X, queue=queue) params = self._get_onedal_params(X_table, stage="predict") - - result = self._get_backend( - "decomposition", "dim_reduction", "infer", policy, params, model, X_table - ) + result = self.infer(params, model, to_table(X)) return from_table(result.transformed_data) class PCA(BasePCA): + @supports_queue def fit(self, X, y=None, queue=None): n_samples, n_features = X.shape n_sf_min = min(n_samples, n_features) self._validate_n_components(self.n_components, n_samples, n_features) - policy = self._get_policy(queue, X) # TODO: investigate why np.ndarray with OWNDATA=FALSE flag # fails to be converted to oneDAL table if isinstance(X, np.ndarray) and not X.flags["OWNDATA"]: @@ -154,9 +163,7 @@ def fit(self, X, y=None, queue=None): X = to_table(X, queue=queue) params = self._get_onedal_params(X) - result = self._get_backend( - "decomposition", "dim_reduction", "train", policy, params, X - ) + result = self.train(params, X) self.mean_ = from_table(result.means).ravel() self.variances_ = from_table(result.variances) @@ -169,10 +176,6 @@ def fit(self, X, y=None, queue=None): self.n_samples_ = n_samples self.n_features_ = n_features - U = None - S = self.singular_values_ - Vt = self.components_ - n_components = self._resolve_n_components_for_result(X.shape) self.n_components_ = n_components self.noise_variance_ = self._compute_noise_variance(n_components, n_sf_min) diff --git a/onedal/ensemble/forest.py b/onedal/ensemble/forest.py index 0a006bf9b1..87167133a2 100644 --- a/onedal/ensemble/forest.py +++ b/onedal/ensemble/forest.py @@ -24,13 +24,14 @@ from sklearn.utils import check_random_state from daal4py.sklearn._utils import daal_check_version +from onedal._device_offload import SyclQueueManager, supports_queue +from onedal.common._backend import bind_default_backend from sklearnex import get_hyperparameters -from ..common._base import BaseEstimator from ..common._estimator_checks import _check_is_fitted from ..common._mixin import ClassifierMixin, RegressorMixin from ..datatypes import from_table, to_table -from ..utils import ( +from ..utils.validation import ( _check_array, _check_n_features, _check_X_y, @@ -39,7 +40,7 @@ ) -class BaseForest(BaseEstimator, BaseEnsemble, metaclass=ABCMeta): +class BaseForest(BaseEnsemble, metaclass=ABCMeta): @abstractmethod def __init__( self, @@ -96,6 +97,12 @@ def __init__( self.variable_importance_mode = variable_importance_mode self.algorithm = algorithm + @abstractmethod + def train(self, *args, 
**kwargs): ... + + @abstractmethod + def infer(self, *args, **kwargs): ... + def _to_absolute_max_features(self, n_features): if self.max_features is None: return n_features @@ -288,7 +295,7 @@ def _get_sample_weight(self, sample_weight, X): return sample_weight - def _fit(self, X, y, sample_weight, module, queue): + def _fit(self, X, y, sample_weight): X, y = _check_X_y( X, y, @@ -305,10 +312,9 @@ def _fit(self, X, y, sample_weight, module, queue): data = (X, y, sample_weight) else: data = (X, y) - policy = self._get_policy(queue, *data) - data = to_table(*data, queue=queue) + data = to_table(*data, queue=SyclQueueManager.get_global_queue()) params = self._get_onedal_params(data[0]) - train_result = module.train(policy, params, *data) + train_result = self.train(params, *data) self._onedal_model = train_result.model @@ -345,41 +351,39 @@ def _create_model(self, module): # upate error msg. raise NotImplementedError("Creating model is not supported.") - def _predict(self, X, module, queue, hparams=None): + def _predict(self, X, hparams=None): _check_is_fitted(self) X = _check_array( X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse=False ) _check_n_features(self, X, False) - policy = self._get_policy(queue, X) model = self._onedal_model - X = to_table(X, queue=queue) + X = to_table(X, queue=SyclQueueManager.get_global_queue()) params = self._get_onedal_params(X) if hparams is not None and not hparams.is_default: - result = module.infer(policy, params, hparams.backend, model, X) + result = self.infer(params, hparams.backend, model, X) else: - result = module.infer(policy, params, model, X) + result = self.infer(params, model, X) y = from_table(result.responses) return y - def _predict_proba(self, X, module, queue, hparams=None): + def _predict_proba(self, X, hparams=None): _check_is_fitted(self) X = _check_array( X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse=False ) _check_n_features(self, X, False) - policy = self._get_policy(queue, X) - X = to_table(X, queue=queue) + X = to_table(X, queue=SyclQueueManager.get_global_queue()) params = self._get_onedal_params(X) params["infer_mode"] = "class_probabilities" model = self._onedal_model if hparams is not None and not hparams.is_default: - result = module.infer(policy, params, hparams.backend, model, X) + result = self.infer(params, hparams.backend, model, X) else: - result = module.infer(policy, params, model, X) + result = self.infer(params, model, X) y = from_table(result.probabilities) return y @@ -443,6 +447,12 @@ def __init__( algorithm=algorithm, ) + @bind_default_backend("decision_forest.classification") + def train(self, *args, **kwargs): ... + + @bind_default_backend("decision_forest.classification") + def infer(self, *args, **kwargs): ... 
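(The `_fit`/`_predict` helpers above show the pattern this patch applies everywhere: private methods stop taking a `queue` argument and instead read it back from `SyclQueueManager`. Below is a minimal host-only sketch of that flow; it assumes the default `target_offload` configuration, and the exact semantics of `manage_global_queue` are inferred from its use in this diff.)

```python
# Host-only sketch: @supports_queue wraps public entry points in roughly
# this context; with plain numpy input there is no SYCL queue to find.
import numpy as np

from onedal._device_offload import SyclQueueManager

X = np.ones((4, 2), dtype=np.float32)

with SyclQueueManager.manage_global_queue(None, X):
    # inside helpers such as _fit(), the queue is recovered globally
    queue = SyclQueueManager.get_global_queue()
    print(queue)  # None for host data under target_offload="auto"
```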
+ def _validate_targets(self, y, dtype): y, self.class_weight_, self.classes_ = _validate_targets( y, self.class_weight, dtype @@ -455,35 +465,22 @@ def _validate_targets(self, y, dtype): # self.n_classes_ = self.classes_ return y + @supports_queue def fit(self, X, y, sample_weight=None, queue=None): - return self._fit( - X, - y, - sample_weight, - self._get_backend("decision_forest", "classification", None), - queue, - ) + return self._fit(X, y, sample_weight) + @supports_queue def predict(self, X, queue=None): hparams = get_hyperparameters("decision_forest", "infer") - pred = super()._predict( - X, - self._get_backend("decision_forest", "classification", None), - queue, - hparams, - ) + pred = self._predict(X, hparams) return np.take(self.classes_, pred.ravel().astype(np.int64, casting="unsafe")) + @supports_queue def predict_proba(self, X, queue=None): hparams = get_hyperparameters("decision_forest", "infer") - return super()._predict_proba( - X, - self._get_backend("decision_forest", "classification", None), - queue, - hparams, - ) + return super()._predict_proba(X, hparams) class RandomForestRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta): @@ -544,25 +541,23 @@ def __init__( algorithm=algorithm, ) + @bind_default_backend("decision_forest.regression") + def train(self, *args, **kwargs): ... + + @bind_default_backend("decision_forest.regression") + def infer(self, *args, **kwargs): ... + + @supports_queue def fit(self, X, y, sample_weight=None, queue=None): if sample_weight is not None: if hasattr(sample_weight, "__array__"): sample_weight[sample_weight == 0.0] = 1.0 sample_weight = [sample_weight] - return super()._fit( - X, - y, - sample_weight, - self._get_backend("decision_forest", "regression", None), - queue, - ) + return self._fit(X, y, sample_weight) + @supports_queue def predict(self, X, queue=None): - return ( - super() - ._predict(X, self._get_backend("decision_forest", "regression", None), queue) - .ravel() - ) + return self._predict(X).ravel() class ExtraTreesClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta): @@ -623,6 +618,12 @@ def __init__( algorithm=algorithm, ) + @bind_default_backend("decision_forest.classification") + def train(self, *args, **kwargs): ... + + @bind_default_backend("decision_forest.classification") + def infer(self, *args, **kwargs): ... + def _validate_targets(self, y, dtype): y, self.class_weight_, self.classes_ = _validate_targets( y, self.class_weight, dtype @@ -635,26 +636,19 @@ def _validate_targets(self, y, dtype): # self.n_classes_ = self.classes_ return y + @supports_queue def fit(self, X, y, sample_weight=None, queue=None): - return self._fit( - X, - y, - sample_weight, - self._get_backend("decision_forest", "classification", None), - queue, - ) + return self._fit(X, y, sample_weight) + @supports_queue def predict(self, X, queue=None): - pred = super()._predict( - X, self._get_backend("decision_forest", "classification", None), queue - ) + pred = self._predict(X) return np.take(self.classes_, pred.ravel().astype(np.int64, casting="unsafe")) + @supports_queue def predict_proba(self, X, queue=None): - return super()._predict_proba( - X, self._get_backend("decision_forest", "classification", None), queue - ) + return super()._predict_proba(X) class ExtraTreesRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta): @@ -715,22 +709,20 @@ def __init__( algorithm=algorithm, ) + @bind_default_backend("decision_forest.regression") + def train(self, *args, **kwargs): ... 
+ + @bind_default_backend("decision_forest.regression") + def infer(self, *args, **kwargs): ... + + @supports_queue def fit(self, X, y, sample_weight=None, queue=None): if sample_weight is not None: if hasattr(sample_weight, "__array__"): sample_weight[sample_weight == 0.0] = 1.0 sample_weight = [sample_weight] - return super()._fit( - X, - y, - sample_weight, - self._get_backend("decision_forest", "regression", None), - queue, - ) + return self._fit(X, y, sample_weight) + @supports_queue def predict(self, X, queue=None): - return ( - super() - ._predict(X, self._get_backend("decision_forest", "regression", None), queue) - .ravel() - ) + return self._predict(X).ravel() diff --git a/onedal/linear_model/incremental_linear_model.py b/onedal/linear_model/incremental_linear_model.py index bc48d59077..a52992e8c6 100644 --- a/onedal/linear_model/incremental_linear_model.py +++ b/onedal/linear_model/incremental_linear_model.py @@ -17,10 +17,12 @@ import numpy as np from daal4py.sklearn._utils import get_dtype +from onedal._device_offload import SyclQueueManager, supports_queue +from onedal.common._backend import bind_default_backend from ..common.hyperparameters import get_hyperparameters from ..datatypes import from_table, to_table -from ..utils import _check_X_y, _num_features +from ..utils.validation import _check_X_y, _num_features from .linear_model import BaseLinearRegression @@ -44,13 +46,24 @@ class IncrementalLinearRegression(BaseLinearRegression): def __init__(self, fit_intercept=True, copy_X=False, algorithm="norm_eq"): super().__init__(fit_intercept=fit_intercept, copy_X=copy_X, algorithm=algorithm) + self._queue = None self._reset() + @bind_default_backend("linear_model.regression") + def partial_train_result(self): ... + + @bind_default_backend("linear_model.regression") + def partial_train(self, *args, **kwargs): ... + + @bind_default_backend("linear_model.regression") + def finalize_train(self, *args, **kwargs): ... + def _reset(self): - self._partial_result = self._get_backend( - "linear_model", "regression", "partial_train_result" - ) + # Get the pointer to partial_result from backend + self._queue = None + self._partial_result = self.partial_train_result() + @supports_queue def partial_fit(self, X, y, queue=None): """ Computes partial data for linear regression @@ -72,11 +85,7 @@ def partial_fit(self, X, y, queue=None): self : object Returns the instance itself. """ - module = self._get_backend("linear_model", "regression") - self._queue = queue - policy = self._get_policy(queue, X) - X, y = _check_X_y( X, y, dtype=[np.float64, np.float32], accept_2d_y=True, force_all_finite=False ) @@ -92,20 +101,16 @@ def partial_fit(self, X, y, queue=None): hparams = get_hyperparameters("linear_regression", "train") if hparams is not None and not hparams.is_default: - self._partial_result = module.partial_train( - policy, - self._params, - hparams.backend, - self._partial_result, - X_table, - y_table, + self._partial_result = self.partial_train( + self._params, hparams.backend, self._partial_result, X_table, y_table ) else: - self._partial_result = module.partial_train( - policy, self._params, self._partial_result, X_table, y_table + self._partial_result = self.partial_train( + self._params, self._partial_result, X_table, y_table ) + self._queue = queue - def finalize_fit(self, queue=None): + def finalize_fit(self): """ Finalizes linear regression computation and obtains coefficients from the current `_partial_result`. 
@@ -121,19 +126,14 @@ def finalize_fit(self, queue=None): Returns the instance itself. """ - if queue is not None: - policy = self._get_policy(queue) - else: - policy = self._get_policy(self._queue) - - module = self._get_backend("linear_model", "regression") hparams = get_hyperparameters("linear_regression", "train") - if hparams is not None and not hparams.is_default: - result = module.finalize_train( - policy, self._params, hparams.backend, self._partial_result - ) - else: - result = module.finalize_train(policy, self._params, self._partial_result) + with SyclQueueManager.manage_global_queue(self._queue): + if hparams is not None and not hparams.is_default: + result = self.finalize_train( + self._params, hparams.backend, self._partial_result + ) + else: + result = self.finalize_train(self._params, self._partial_result) self._onedal_model = result.model @@ -143,6 +143,8 @@ def finalize_fit(self, queue=None): packed_coefficients[:, 0].squeeze(), ) + self._queue = None + return self @@ -170,16 +172,26 @@ class IncrementalRidge(BaseLinearRegression): """ def __init__(self, alpha=1.0, fit_intercept=True, copy_X=False, algorithm="norm_eq"): - module = self._get_backend("linear_model", "regression") super().__init__( fit_intercept=fit_intercept, alpha=alpha, copy_X=copy_X, algorithm=algorithm ) - self._partial_result = module.partial_train_result() + self._queue = None + self._reset() def _reset(self): - module = self._get_backend("linear_model", "regression") - self._partial_result = module.partial_train_result() + self._queue = None + self._partial_result = self.partial_train_result() + @bind_default_backend("linear_model.regression") + def partial_train_result(self): ... + + @bind_default_backend("linear_model.regression") + def partial_train(self, *args, **kwargs): ... + + @bind_default_backend("linear_model.regression") + def finalize_train(self, *args, **kwargs): ... + + @supports_queue def partial_fit(self, X, y, queue=None): """ Computes partial data for ridge regression @@ -201,11 +213,7 @@ def partial_fit(self, X, y, queue=None): self : object Returns the instance itself. """ - module = self._get_backend("linear_model", "regression") - self._queue = queue - policy = self._get_policy(queue, X) - X, y = _check_X_y( X, y, dtype=[np.float64, np.float32], accept_2d_y=True, force_all_finite=False ) @@ -219,11 +227,17 @@ def partial_fit(self, X, y, queue=None): self._dtype = X_table.dtype self._params = self._get_onedal_params(self._dtype) - self._partial_result = module.partial_train( - policy, self._params, self._partial_result, X_table, y_table - ) + hparams = get_hyperparameters("linear_regression", "train") + if hparams is not None and not hparams.is_default: + self._partial_result = self.partial_train( + self._params, hparams.backend, self._partial_result, X_table, y_table + ) + else: + self._partial_result = self.partial_train( + self._params, self._partial_result, X_table, y_table + ) - def finalize_fit(self, queue=None): + def finalize_fit(self): """ Finalizes ridge regression computation and obtains coefficients from the current `_partial_result`. @@ -238,12 +252,8 @@ def finalize_fit(self, queue=None): self : object Returns the instance itself. 
""" - module = self._get_backend("linear_model", "regression") - if queue is not None: - policy = self._get_policy(queue) - else: - policy = self._get_policy(self._queue) - result = module.finalize_train(policy, self._params, self._partial_result) + with SyclQueueManager.manage_global_queue(self._queue): + result = self.finalize_train(self._params, self._partial_result) self._onedal_model = result.model @@ -253,4 +263,6 @@ def finalize_fit(self, queue=None): packed_coefficients[:, 0].squeeze(), ) + self._queue = None + return self diff --git a/onedal/linear_model/linear_model.py b/onedal/linear_model/linear_model.py index 264a571de0..41b7114ae8 100755 --- a/onedal/linear_model/linear_model.py +++ b/onedal/linear_model/linear_model.py @@ -20,15 +20,16 @@ import numpy as np from daal4py.sklearn._utils import daal_check_version, get_dtype, make2d +from onedal._device_offload import SyclQueueManager, supports_queue +from onedal.common._backend import bind_default_backend -from ..common._base import BaseEstimator from ..common._estimator_checks import _check_is_fitted from ..common.hyperparameters import get_hyperparameters from ..datatypes import from_table, to_table -from ..utils import _check_array, _check_n_features, _check_X_y, _num_features +from ..utils.validation import _check_array, _check_n_features, _check_X_y, _num_features -class BaseLinearRegression(BaseEstimator, metaclass=ABCMeta): +class BaseLinearRegression(metaclass=ABCMeta): """ Base class for LinearRegression oneDAL implementation. """ @@ -40,6 +41,16 @@ def __init__(self, fit_intercept, copy_X, algorithm, alpha=0.0): self.copy_X = copy_X self.algorithm = algorithm + @bind_default_backend("linear_model.regression") + def train(self, *args, **kwargs): ... + + @bind_default_backend("linear_model.regression") + def infer(self, params, model, X): ... + + # direct access to the backend model class + @bind_default_backend("linear_model.regression") + def model(self): ... + def _get_onedal_params(self, dtype=np.float32): intercept = "intercept|" if self.fit_intercept else "" params = { @@ -53,9 +64,8 @@ def _get_onedal_params(self, dtype=np.float32): return params - def _create_model(self, policy): - module = self._get_backend("linear_model", "regression") - model = module.model() + def _create_model(self): + model = self.model() coefficients = self.coef_ dtype = get_dtype(coefficients) @@ -92,13 +102,14 @@ def _create_model(self, policy): packed_coefficients[:, 0][:, np.newaxis] = intercept model.packed_coefficients = to_table( - packed_coefficients, queue=getattr(policy, "_queue", None) + packed_coefficients, queue=SyclQueueManager.get_global_queue() ) self._onedal_model = model return model + @supports_queue def predict(self, X, queue=None): """ Predict using the linear model. @@ -115,12 +126,9 @@ def predict(self, X, queue=None): C : array, shape (n_samples, n_targets) Returns predicted values. 
""" - module = self._get_backend("linear_model", "regression") _check_is_fitted(self) - policy = self._get_policy(queue, X) - X = _check_array( X, dtype=[np.float64, np.float32], force_all_finite=False, ensure_2d=False ) @@ -129,12 +137,12 @@ def predict(self, X, queue=None): if hasattr(self, "_onedal_model"): model = self._onedal_model else: - model = self._create_model(policy) + model = self._create_model() X_table = to_table(X, queue=queue) params = self._get_onedal_params(X_table.dtype) - result = module.infer(policy, params, model, X_table) + result = self.infer(params, model, X_table) y = from_table(result.responses) if y.shape[1] == 1 and self.coef_.ndim == 1: @@ -171,6 +179,7 @@ def __init__( ): super().__init__(fit_intercept=fit_intercept, copy_X=copy_X, algorithm=algorithm) + @supports_queue def fit(self, X, y, queue=None): """ Fit linear model. @@ -190,7 +199,6 @@ def fit(self, X, y, queue=None): self : object Fitted Estimator. """ - module = self._get_backend("linear_model", "regression") # TODO Fix _check_X_y to make sure this conversion is there if not isinstance(X, np.ndarray): @@ -205,8 +213,6 @@ def fit(self, X, y, queue=None): X, y = _check_X_y(X, y, force_all_finite=False, accept_2d_y=True) - policy = self._get_policy(queue, X, y) - self.n_features_in_ = _num_features(X, fallback_1d=True) X_table, y_table = to_table(X, y, queue=queue) @@ -214,9 +220,9 @@ def fit(self, X, y, queue=None): hparams = get_hyperparameters("linear_regression", "train") if hparams is not None and not hparams.is_default: - result = module.train(policy, params, hparams.backend, X_table, y_table) + result = self.train(params, hparams.backend, X_table, y_table) else: - result = module.train(policy, params, X_table, y_table) + result = self.train(params, X_table, y_table) self._onedal_model = result.model @@ -269,6 +275,7 @@ def __init__( fit_intercept=fit_intercept, alpha=alpha, copy_X=copy_X, algorithm=algorithm ) + @supports_queue def fit(self, X, y, queue=None): """ Fit linear model. @@ -288,8 +295,13 @@ def fit(self, X, y, queue=None): self : object Fitted Estimator. 
""" - module = self._get_backend("linear_model", "regression") - + X = _check_array( + X, + dtype=[np.float64, np.float32], + force_all_finite=False, + ensure_2d=False, + copy=self.copy_X, + ) if not isinstance(X, np.ndarray): X = np.asarray(X) @@ -302,14 +314,12 @@ def fit(self, X, y, queue=None): X, y = _check_X_y(X, y, force_all_finite=False, accept_2d_y=True) - policy = self._get_policy(queue, X, y) - self.n_features_in_ = _num_features(X, fallback_1d=True) X_table, y_table = to_table(X, y, queue=queue) params = self._get_onedal_params(X.dtype) - result = module.train(policy, params, X_table, y_table) + result = self.train(params, X_table, y_table) self._onedal_model = result.model packed_coefficients = from_table(result.model.packed_coefficients) diff --git a/onedal/linear_model/logistic_regression.py b/onedal/linear_model/logistic_regression.py index 53e5f293ce..7175e797f0 100644 --- a/onedal/linear_model/logistic_regression.py +++ b/onedal/linear_model/logistic_regression.py @@ -20,12 +20,13 @@ import numpy as np from daal4py.sklearn._utils import daal_check_version, get_dtype, make2d +from onedal._device_offload import SyclQueueManager, supports_queue +from onedal.common._backend import bind_default_backend -from ..common._base import BaseEstimator as onedal_BaseEstimator from ..common._estimator_checks import _check_is_fitted from ..common._mixin import ClassifierMixin from ..datatypes import from_table, to_table -from ..utils import ( +from ..utils.validation import ( _check_array, _check_n_features, _check_X_y, @@ -35,7 +36,7 @@ ) -class BaseLogisticRegression(onedal_BaseEstimator, metaclass=ABCMeta): +class BaseLogisticRegression(metaclass=ABCMeta): @abstractmethod def __init__(self, tol, C, fit_intercept, solver, max_iter, algorithm): self.tol = tol @@ -45,6 +46,16 @@ def __init__(self, tol, C, fit_intercept, solver, max_iter, algorithm): self.max_iter = max_iter self.algorithm = algorithm + @abstractmethod + def train(self, params, X, y): ... + + @abstractmethod + def infer(self, params, X): ... + + # direct access to the backend model constructor + @abstractmethod + def model(self): ... 
+ def _get_onedal_params(self, is_csr, dtype=np.float32): intercept = "intercept|" if self.fit_intercept else "" return { @@ -62,7 +73,7 @@ def _get_onedal_params(self, is_csr, dtype=np.float32): ), } - def _fit(self, X, y, module, queue): + def _fit(self, X, y): sparsity_enabled = daal_check_version((2024, "P", 700)) X, y = _check_X_y( X, @@ -82,11 +93,10 @@ def _fit(self, X, y, module, queue): self.classes_, y = np.unique(y, return_inverse=True) y = y.astype(dtype=np.int32) - policy = self._get_policy(queue, X, y) - X_table, y_table = to_table(X, y, queue=queue) + X_table, y_table = to_table(X, y, queue=SyclQueueManager.get_global_queue()) params = self._get_onedal_params(is_csr, X_table.dtype) - result = module.train(policy, params, X_table, y_table) + result = self.train(params, X_table, y_table) self._onedal_model = result.model self.n_iter_ = np.array([result.iterations_count]) @@ -100,8 +110,8 @@ def _fit(self, X, y, module, queue): return self - def _create_model(self, module, policy): - m = module.model() + def _create_model(self): + m = self.model() coefficients = self.coef_ dtype = get_dtype(coefficients) @@ -144,14 +154,14 @@ def _create_model(self, module, policy): packed_coefficients[:, 0][:, np.newaxis] = intercept m.packed_coefficients = to_table( - packed_coefficients, queue=getattr(policy, "_queue", None) + packed_coefficients, queue=SyclQueueManager.get_global_queue() ) self._onedal_model = m return m - def _infer(self, X, module, queue): + def _infer(self, X): _check_is_fitted(self) sparsity_enabled = daal_check_version((2024, "P", 700)) @@ -167,34 +177,33 @@ def _infer(self, X, module, queue): _check_n_features(self, X, False) X = make2d(X) - policy = self._get_policy(queue, X) if hasattr(self, "_onedal_model"): model = self._onedal_model else: - model = self._create_model(module, policy) + model = self._create_model() - X_table = to_table(X, queue=queue) + X_table = to_table(X, queue=SyclQueueManager.get_global_queue()) params = self._get_onedal_params(is_csr, X.dtype) - result = module.infer(policy, params, model, X_table) + result = self.infer(params, model, X_table) return result - def _predict(self, X, module, queue): - result = self._infer(X, module, queue) + def _predict(self, X): + result = self._infer(X) y = from_table(result.responses) y = np.take(self.classes_, y.ravel(), axis=0) return y - def _predict_proba(self, X, module, queue): - result = self._infer(X, module, queue) + def _predict_proba(self, X): + result = self._infer(X) y = from_table(result.probabilities) y = y.reshape(-1, 1) return np.hstack([1 - y, y]) - def _predict_log_proba(self, X, module, queue): - y_proba = self._predict_proba(X, module, queue) + def _predict_log_proba(self, X): + y_proba = self._predict_proba(X) return np.log(y_proba) @@ -223,25 +232,27 @@ def __init__( algorithm=algorithm, ) + @bind_default_backend("logistic_regression.classification") + def train(self, params, X, y, queue=None): ... + + @bind_default_backend("logistic_regression.classification") + def infer(self, params, X, model, queue=None): ... + + @bind_default_backend("logistic_regression.classification") + def model(self): ... 
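(A condensed view of the base/concrete split used for logistic regression above: the abstract base fixes the call signatures, and each concrete estimator binds them to a backend submodule. The class names below are illustrative; the decorator and module path are the ones from this diff.)

```python
from abc import ABC, abstractmethod

from onedal.common._backend import bind_default_backend


class _BaseSketch(ABC):
    # the base class only declares the hook
    @abstractmethod
    def train(self, params, X, y): ...


class _ConcreteSketch(_BaseSketch):
    # the subclass decides which backend submodule implements it
    @bind_default_backend("logistic_regression.classification")
    def train(self, params, X, y): ...
```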
+ + @supports_queue def fit(self, X, y, queue=None): - return super()._fit( - X, y, self._get_backend("logistic_regression", "classification", None), queue - ) + return self._fit(X, y) + @supports_queue def predict(self, X, queue=None): - y = super()._predict( - X, self._get_backend("logistic_regression", "classification", None), queue - ) - return y + return self._predict(X) + @supports_queue def predict_proba(self, X, queue=None): - y = super()._predict_proba( - X, self._get_backend("logistic_regression", "classification", None), queue - ) - return y + return self._predict_proba(X) + @supports_queue def predict_log_proba(self, X, queue=None): - y = super()._predict_log_proba( - X, self._get_backend("logistic_regression", "classification", None), queue - ) - return y + return self._predict_log_proba(X) diff --git a/onedal/neighbors/neighbors.py b/onedal/neighbors/neighbors.py index b97706e49a..815ab7148c 100755 --- a/onedal/neighbors/neighbors.py +++ b/onedal/neighbors/neighbors.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -from abc import ABCMeta +from abc import ABCMeta, abstractmethod from numbers import Integral import numpy as np @@ -27,12 +27,13 @@ kdtree_knn_classification_prediction, kdtree_knn_classification_training, ) +from onedal._device_offload import SyclQueueManager, supports_queue +from onedal.common._backend import bind_default_backend -from ..common._base import BaseEstimator from ..common._estimator_checks import _check_is_fitted, _is_classifier, _is_regressor from ..common._mixin import ClassifierMixin, RegressorMixin from ..datatypes import from_table, to_table -from ..utils import ( +from ..utils.validation import ( _check_array, _check_classification_targets, _check_n_features, @@ -42,7 +43,18 @@ ) -class NeighborsCommonBase(BaseEstimator, metaclass=ABCMeta): +class NeighborsCommonBase(metaclass=ABCMeta): + def __init__(self): + self.requires_y = False + self.n_neighbors = None + self.metric = None + self.classes_ = None + self.effective_metric_ = None + self._fit_method = None + self.radius = None + self.effective_metric_params_ = None + self._onedal_model = None + def _parse_auto_method(self, method, n_samples, n_features): result_method = method @@ -60,8 +72,17 @@ def _parse_auto_method(self, method, n_samples, n_features): return result_method + @abstractmethod + def train(self, *args, **kwargs): ... + + @abstractmethod + def infer(self, *args, **kwargs): ... + + @abstractmethod + def _onedal_fit(self, X, y): ... 
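(With the public entry points above decorated by `@supports_queue`, callers can pass `queue=` explicitly or rely on the `target_offload` setting that `SyclQueueManager` consults. A hypothetical call pattern follows; whether `"gpu:0"` resolves depends on the installed SYCL runtime.)

```python
import numpy as np

from sklearnex import config_context
from sklearnex.linear_model import LogisticRegression

X = np.random.rand(100, 4).astype(np.float32)
y = (X[:, 0] > 0.5).astype(np.int32)

# device selection flows through the queue manager instead of an
# explicit queue argument on every call
with config_context(target_offload="gpu:0"):
    model = LogisticRegression().fit(X, y)
```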
+ def _validate_data( - self, X, y=None, reset=True, validate_separately=False, **check_params + self, X, y=None, reset=True, validate_separately=None, **check_params ): if y is None: if self.requires_y: @@ -188,13 +209,13 @@ def _validate_targets(self, y, dtype): return arr def _validate_n_classes(self): - if len(self.classes_) < 2: + length = 0 if self.classes_ is None else len(self.classes_) + if length < 2: raise ValueError( - "The number of classes has to be greater than one; got %d" - " class" % len(self.classes_) + f"The number of classes has to be greater than one; got {length}" ) - def _fit(self, X, y, queue): + def _fit(self, X, y): self._onedal_model = None self._tree = None self._shape = None @@ -253,11 +274,12 @@ def _fit(self, X, y, queue): ) _fit_y = None + queue = SyclQueueManager.get_global_queue() gpu_device = queue is not None and queue.sycl_device.is_gpu if _is_classifier(self) or (_is_regressor(self) and gpu_device): _fit_y = self._validate_targets(self._y, X.dtype).reshape((-1, 1)) - result = self._onedal_fit(X, _fit_y, queue) + result = self._onedal_fit(X, _fit_y) if y is not None and _is_regressor(self): self._y = y if self._shape is None else y.reshape(self._shape) @@ -267,7 +289,7 @@ def _fit(self, X, y, queue): return result - def _kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None): + def _kneighbors(self, X=None, n_neighbors=None, return_distance=True): n_features = getattr(self, "n_features_in_", None) shape = getattr(X, "shape", None) if n_features and shape and len(shape) > 1 and shape[1] != n_features: @@ -316,25 +338,21 @@ def _kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None ) chunked_results = None - method = super()._parse_auto_method( + method = self._parse_auto_method( self._fit_method, self.n_samples_fit_, n_features ) - if ( - type(self._onedal_model) is kdtree_knn_classification_model - or type(self._onedal_model) is bf_knn_classification_model + if type(self._onedal_model) in ( + kdtree_knn_classification_model, + bf_knn_classification_model, ): params = super()._get_daal_params(X, n_neighbors=n_neighbors) - prediction_results = self._onedal_predict( - self._onedal_model, X, params, queue=queue - ) + prediction_results = self._onedal_predict(self._onedal_model, X, params) distances = prediction_results.distances indices = prediction_results.indices else: params = super()._get_onedal_params(X, n_neighbors=n_neighbors) - prediction_results = self._onedal_predict( - self._onedal_model, X, params, queue=queue - ) + prediction_results = self._onedal_predict(self._onedal_model, X, params) distances = from_table(prediction_results.distances) indices = from_table(prediction_results.indices) @@ -408,14 +426,27 @@ def __init__( ) self.weights = weights + # direct access to the backend model constructor + @bind_default_backend("neighbors.classification") + def model(self): ... + + # direct access to the backend model constructor + @bind_default_backend("neighbors.classification") + def train(self, *args, **kwargs): ... + + @bind_default_backend("neighbors.classification") + def infer(self, *args, **kwargs): ... 
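(The `_onedal_fit` bodies that follow keep the legacy daal4py kNN path for the Euclidean metric on CPU and send everything else through the bound oneDAL backend. The decision reduces to roughly this predicate, written out here only for illustration:)

```python
def _uses_daal4py_knn(effective_metric, queue):
    # CPU + euclidean -> daal4py bf/kd-tree kNN; anything else -> oneDAL
    gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
    return effective_metric == "euclidean" and not gpu_device
```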
+ def _get_daal_params(self, data): params = super()._get_daal_params(data) params["resultsToEvaluate"] = "computeClassLabels" params["resultsToCompute"] = "" return params - def _onedal_fit(self, X, y, queue): - gpu_device = queue is not None and queue.sycl_device.is_gpu + def _onedal_fit(self, X, y): + # global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function + queue = SyclQueueManager.get_global_queue() + gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) if self.effective_metric_ == "euclidean" and not gpu_device: params = self._get_daal_params(X) if self._fit_method == "brute": @@ -425,42 +456,30 @@ def _onedal_fit(self, X, y, queue): train_alg = kdtree_knn_classification_training return train_alg(**params).compute(X, y).model + else: + params = self._get_onedal_params(X, y) + X_table, y_table = to_table(X, y, queue=queue) + return self.train(params, X_table, y_table).model - policy = self._get_policy(queue, X, y) - X_table, y_table = to_table(X, y, queue=queue) - params = self._get_onedal_params(X_table, y) - train_alg = self._get_backend( - "neighbors", "classification", "train", policy, params, X_table, y_table - ) - - return train_alg.model - - def _onedal_predict(self, model, X, params, queue): + def _onedal_predict(self, model, X, params): if type(self._onedal_model) is kdtree_knn_classification_model: return kdtree_knn_classification_prediction(**params).compute(X, model) elif type(self._onedal_model) is bf_knn_classification_model: return bf_knn_classification_prediction(**params).compute(X, model) - - policy = self._get_policy(queue, X) - X = to_table(X, queue=queue) - if hasattr(self, "_onedal_model"): - model = self._onedal_model else: - model = self._create_model( - self._get_backend("neighbors", "classification", None) - ) - if "responses" not in params["result_option"]: - params["result_option"] += "|responses" - params["fptype"] = X.dtype - result = self._get_backend( - "neighbors", "classification", "infer", policy, params, model, X - ) + X = to_table(X, queue=SyclQueueManager.get_global_queue()) + if "responses" not in params["result_option"]: + params["result_option"] += "|responses" + params["fptype"] = X.dtype + result = self.infer(params, model, X) - return result + return result + @supports_queue def fit(self, X, y, queue=None): - return super()._fit(X, y, queue=queue) + return self._fit(X, y) + @supports_queue def predict(self, X, queue=None): X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32]) onedal_model = getattr(self, "_onedal_model", None) @@ -478,7 +497,7 @@ def predict(self, X, queue=None): _check_is_fitted(self) - self._fit_method = super()._parse_auto_method( + self._fit_method = self._parse_auto_method( self.algorithm, n_samples_fit_, n_features ) @@ -489,16 +508,17 @@ def predict(self, X, queue=None): or type(onedal_model) is bf_knn_classification_model ): params = self._get_daal_params(X) - prediction_result = self._onedal_predict(onedal_model, X, params, queue=queue) + prediction_result = self._onedal_predict(onedal_model, X, params) responses = prediction_result.prediction else: params = self._get_onedal_params(X) - prediction_result = self._onedal_predict(onedal_model, X, params, queue=queue) + prediction_result = self._onedal_predict(onedal_model, X, params) responses = from_table(prediction_result.responses) result = self.classes_.take(np.asarray(responses.ravel(), dtype=np.intp)) return result + @supports_queue def predict_proba(self, 
X, queue=None): neigh_dist, neigh_ind = self.kneighbors(X, queue=queue) @@ -536,8 +556,9 @@ def predict_proba(self, X, queue=None): return probabilities + @supports_queue def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None): - return super()._kneighbors(X, n_neighbors, return_distance, queue=queue) + return self._kneighbors(X, n_neighbors, return_distance) class KNeighborsRegressor(NeighborsBase, RegressorMixin): @@ -562,67 +583,76 @@ def __init__( ) self.weights = weights + @bind_default_backend("neighbors.search", lookup_name="train") + def train_search(self, *args, **kwargs): ... + + @bind_default_backend("neighbors.search", lookup_name="infer") + def infer_search(self, *args, **kwargs): ... + + @bind_default_backend("neighbors.regression") + def train(self, *args, **kwargs): ... + + @bind_default_backend("neighbors.regression") + def infer(self, *args, **kwargs): ... + def _get_daal_params(self, data): params = super()._get_daal_params(data) params["resultsToCompute"] = "computeIndicesOfNeighbors|computeDistances" params["resultsToEvaluate"] = "none" return params - def _onedal_fit(self, X, y, queue): - gpu_device = queue is not None and queue.sycl_device.is_gpu + def _onedal_fit(self, X, y): + # global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function + queue = SyclQueueManager.get_global_queue() + gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) if self.effective_metric_ == "euclidean" and not gpu_device: params = self._get_daal_params(X) if self._fit_method == "brute": train_alg = bf_knn_classification_training - else: train_alg = kdtree_knn_classification_training return train_alg(**params).compute(X, y).model - policy = self._get_policy(queue, X, y) X_table, y_table = to_table(X, y, queue=queue) params = self._get_onedal_params(X_table, y) - train_alg_regr = self._get_backend("neighbors", "regression", None) - train_alg_srch = self._get_backend("neighbors", "search", None) if gpu_device: - return train_alg_regr.train(policy, params, X_table, y_table).model - return train_alg_srch.train(policy, params, X_table).model + return self.train(params, X_table, y_table).model + else: + return self.train_search(params, X_table).model + + def _onedal_predict(self, model, X, params): + assert self._onedal_model is not None, "Model is not trained" - def _onedal_predict(self, model, X, params, queue): if type(model) is kdtree_knn_classification_model: return kdtree_knn_classification_prediction(**params).compute(X, model) elif type(model) is bf_knn_classification_model: return bf_knn_classification_prediction(**params).compute(X, model) - gpu_device = queue is not None and queue.sycl_device.is_gpu - policy = self._get_policy(queue, X) + # global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function + queue = SyclQueueManager.get_global_queue() + gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) X = to_table(X, queue=queue) - backend = ( - self._get_backend("neighbors", "regression", None) - if gpu_device - else self._get_backend("neighbors", "search", None) - ) - if hasattr(self, "_onedal_model"): - model = self._onedal_model - else: - model = self._create_model(backend) if "responses" not in params["result_option"] and gpu_device: params["result_option"] += "|responses" params["fptype"] = X.dtype - result = backend.infer(policy, params, model, X) - return result + if gpu_device: + return 
self.infer(params, self._onedal_model, X) + else: + return self.infer_search(params, self._onedal_model, X) + @supports_queue def fit(self, X, y, queue=None): - return super()._fit(X, y, queue=queue) + return self._fit(X, y) + @supports_queue def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None): - return super()._kneighbors(X, n_neighbors, return_distance, queue=queue) + return self._kneighbors(X, n_neighbors, return_distance) - def _predict_gpu(self, X, queue=None): + def _predict_gpu(self, X): X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32]) onedal_model = getattr(self, "_onedal_model", None) n_features = getattr(self, "n_features_in_", None) @@ -639,20 +669,20 @@ def _predict_gpu(self, X, queue=None): _check_is_fitted(self) - self._fit_method = super()._parse_auto_method( + self._fit_method = self._parse_auto_method( self.algorithm, n_samples_fit_, n_features ) params = self._get_onedal_params(X) - prediction_result = self._onedal_predict(onedal_model, X, params, queue=queue) + prediction_result = self._onedal_predict(onedal_model, X, params) responses = from_table(prediction_result.responses) result = responses.ravel() return result - def _predict_skl(self, X, queue=None): - neigh_dist, neigh_ind = self.kneighbors(X, queue=queue) + def _predict_skl(self, X): + neigh_dist, neigh_ind = self.kneighbors(X) weights = self._get_weights(neigh_dist, self.weights) @@ -675,14 +705,14 @@ def _predict_skl(self, X, queue=None): return y_pred + @supports_queue def predict(self, X, queue=None): - gpu_device = queue is not None and queue.sycl_device.is_gpu + gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) is_uniform_weights = getattr(self, "weights", "uniform") == "uniform" - return ( - self._predict_gpu(X, queue=queue) - if gpu_device and is_uniform_weights - else self._predict_skl(X, queue=queue) - ) + if gpu_device and is_uniform_weights: + return self._predict_gpu(X) + else: + return self._predict_skl(X) class NearestNeighbors(NeighborsBase): @@ -707,6 +737,12 @@ def __init__( ) self.weights = weights + @bind_default_backend("neighbors.search") + def train(self, *args, **kwargs): ... + + @bind_default_backend("neighbors.search") + def infer(self, *arg, **kwargs): ... 
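(As seen in `KNeighborsRegressor` earlier in this file, one estimator can need two backend `train` symbols; the `lookup_name` argument lets the Python method name diverge from the backend attribute. Sketch of the binding, with module paths exactly as in this diff:)

```python
from onedal.common._backend import bind_default_backend


class _RegressorSketch:
    # Python name train_search -> backend symbol neighbors.search.train
    @bind_default_backend("neighbors.search", lookup_name="train")
    def train_search(self, *args, **kwargs): ...

    # Python name train -> backend symbol neighbors.regression.train
    @bind_default_backend("neighbors.regression")
    def train(self, *args, **kwargs): ...
```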
+
     def _get_daal_params(self, data):
         params = super()._get_daal_params(data)
         params["resultsToCompute"] = "computeIndicesOfNeighbors|computeDistances"
@@ -715,8 +751,10 @@ def _get_daal_params(self, data):
         )
         return params
 
-    def _onedal_fit(self, X, y, queue):
-        gpu_device = queue is not None and queue.sycl_device.is_gpu
+    def _onedal_fit(self, X, y):
+        # the global queue was already set, via user configuration (`target_offload`) or from the data
+        queue = SyclQueueManager.get_global_queue()
+        gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
         if self.effective_metric_ == "euclidean" and not gpu_device:
             params = self._get_daal_params(X)
             if self._fit_method == "brute":
@@ -727,37 +765,26 @@ def _onedal_fit(self, X, y, queue):
 
             return train_alg(**params).compute(X, y).model
 
-        policy = self._get_policy(queue, X, y)
-        X_table = to_table(X, queue=queue)
-        params = self._get_onedal_params(X_table, y)
-        train_alg = self._get_backend(
-            "neighbors", "search", "train", policy, params, X_table
-        )
-
-        return train_alg.model
+        else:
+            params = self._get_onedal_params(X, y)
+            X, y = to_table(X, y, queue=queue)
+            return self.train(params, X).model
 
-    def _onedal_predict(self, model, X, params, queue):
+    def _onedal_predict(self, model, X, params):
         if type(self._onedal_model) is kdtree_knn_classification_model:
             return kdtree_knn_classification_prediction(**params).compute(X, model)
         elif type(self._onedal_model) is bf_knn_classification_model:
             return bf_knn_classification_prediction(**params).compute(X, model)
 
-        policy = self._get_policy(queue, X)
-        X = to_table(X, queue=queue)
-        if hasattr(self, "_onedal_model"):
-            model = self._onedal_model
-        else:
-            model = self._create_model(self._get_backend("neighbors", "search", None))
+        X = to_table(X, queue=SyclQueueManager.get_global_queue())
         params["fptype"] = X.dtype
-        result = self._get_backend(
-            "neighbors", "search", "infer", policy, params, model, X
-        )
-
-        return result
+        return self.infer(params, model, X)
 
+    @supports_queue
     def fit(self, X, y, queue=None):
-        return super()._fit(X, y, queue=queue)
+        return self._fit(X, y)
 
+    @supports_queue
     def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
-        return super()._kneighbors(X, n_neighbors, return_distance, queue=queue)
+        return self._kneighbors(X, n_neighbors, return_distance)
diff --git a/onedal/primitives/get_tree.py b/onedal/primitives/get_tree.py
index 688d5d2983..4bd3b984df 100644
--- a/onedal/primitives/get_tree.py
+++ b/onedal/primitives/get_tree.py
@@ -14,12 +14,14 @@
 # limitations under the License.
 # ==============================================================================
 
-from onedal import _backend
+from onedal import _default_backend, _dpc_backend
+
+backend = _dpc_backend or _default_backend
 
 
 def get_tree_state_cls(model, iTree, n_classes):
-    return _backend.get_tree.classification.get_tree_state(model, iTree, n_classes)
+    return backend.get_tree.classification.get_tree_state(model, iTree, n_classes)
 
 
 def get_tree_state_reg(model, iTree, n_classes):
-    return _backend.get_tree.regression.get_tree_state(model, iTree, n_classes)
+    return backend.get_tree.regression.get_tree_state(model, iTree, n_classes)
diff --git a/onedal/primitives/kernel_functions.py b/onedal/primitives/kernel_functions.py
index d48b326a16..5d4240ef40 100644
--- a/onedal/primitives/kernel_functions.py
+++ b/onedal/primitives/kernel_functions.py
@@ -14,13 +14,14 @@
 # limitations under the License.
 # ===============================================================
 
 import numpy as np
 
-from onedal import _backend
+from onedal import _default_backend as backend
+from onedal._device_offload import SyclQueueManager, supports_queue
+from onedal.common._backend import BackendFunction
 
-from ..common._policy import _get_policy
 from ..datatypes import from_table, to_table
-from ..utils import _check_array
+from ..utils.validation import _check_array
 
 
 def _check_inputs(X, Y):
@@ -32,14 +33,19 @@ def check_input(data):
     return X, Y
 
 
-def _compute_kernel(params, submodule, X, Y, queue):
-    policy = _get_policy(queue, X, Y)
+def _compute_kernel(params, submodule, X, Y):
+    # the global queue was set by `supports_queue` or the `target_offload` config
+    queue = SyclQueueManager.get_global_queue()
     X, Y = to_table(X, Y, queue=queue)
     params["fptype"] = X.dtype
-    result = submodule.compute(policy, params, X, Y)
+    compute_method = BackendFunction(
+        submodule.compute, backend, "compute", no_policy=False
+    )
+    result = compute_method(params, X, Y)
     return from_table(result.values)
 
 
+@supports_queue
 def linear_kernel(X, Y=None, scale=1.0, shift=0.0, queue=None):
     """
     Compute the linear kernel between X and Y:
@@ -60,13 +66,13 @@
     X, Y = _check_inputs(X, Y)
     return _compute_kernel(
         {"method": "dense", "scale": scale, "shift": shift},
-        _backend.linear_kernel,
+        backend.linear_kernel,
         X,
         Y,
-        queue,
     )
 
 
+@supports_queue
 def rbf_kernel(X, Y=None, gamma=None, queue=None):
     """
     Compute the rbf (gaussian) kernel between X and Y:
@@ -90,11 +96,10 @@
     gamma = 1.0 / X.shape[1] if gamma is None else gamma
     sigma = np.sqrt(0.5 / gamma)
 
-    return _compute_kernel(
-        {"method": "dense", "sigma": sigma}, _backend.rbf_kernel, X, Y, queue
-    )
+    return _compute_kernel({"method": "dense", "sigma": sigma}, backend.rbf_kernel, X, Y)
 
 
+@supports_queue
 def poly_kernel(X, Y=None, gamma=1.0, coef0=0.0, degree=3, queue=None):
     """
     Compute the poly kernel between X and Y:
@@ -117,13 +122,13 @@
     X, Y = _check_inputs(X, Y)
     return _compute_kernel(
         {"method": "dense", "scale": gamma, "shift": coef0, "degree": degree},
-        _backend.polynomial_kernel,
+        backend.polynomial_kernel,
         X,
         Y,
-        queue,
     )
 
 
+@supports_queue
 def sigmoid_kernel(X, Y=None, gamma=1.0, coef0=0.0, queue=None):
     """
     Compute the sigmoid kernel between X and Y:
@@ -144,9 +149,5 @@
     X, Y = _check_inputs(X, Y)
     return _compute_kernel(
-        {"method": "dense", "scale": gamma, "shift": coef0},
-        _backend.sigmoid_kernel,
-        X,
-        Y,
-        queue,
+        {"method": "dense", "scale": gamma, "shift": coef0}, backend.sigmoid_kernel, X, Y
     )
diff --git a/onedal/primitives/tests/test_kernel_functions.py b/onedal/primitives/tests/test_kernel_functions.py
index 22a8f562cb..9becc976b4 100644
--- a/onedal/primitives/tests/test_kernel_functions.py
+++ b/onedal/primitives/tests/test_kernel_functions.py
@@ -91,7 +91,7 @@ def test_dense_small_rbf_kernel(queue, gamma, dtype):
     _test_dense_small_rbf_kernel(queue, gamma, dtype)
 
 
-@pass_if_not_implemented_for_gpu(reason="poly kernel is not implemented")
+@pass_if_not_implemented_for_gpu(reason="Polynomial kernel is not implemented for GPU")
 @pytest.mark.parametrize("queue", get_queues())
 def test_dense_self_poly_kernel(queue):
     rng = np.random.RandomState(0)
@@ -116,7 +116,7 @@ def _test_dense_small_poly_kernel(queue, gamma, coef0, degree, dtype):
assert_allclose(result, expected, rtol=tol) -@pass_if_not_implemented_for_gpu(reason="poly kernel is not implemented") +@pass_if_not_implemented_for_gpu(reason="Polynomial kernel is not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("gamma", [0.1, 1.0]) @pytest.mark.parametrize("coef0", [0.0, 1.0]) @@ -126,7 +126,7 @@ def test_dense_small_poly_kernel(queue, gamma, coef0, degree, dtype): _test_dense_small_poly_kernel(queue, gamma, coef0, degree, dtype) -@pass_if_not_implemented_for_gpu(reason="sigmoid kernel is not implemented") +@pass_if_not_implemented_for_gpu(reason="Sigmoid kernel is not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) def test_dense_self_sigmoid_kernel(queue): rng = np.random.RandomState(0) @@ -150,7 +150,7 @@ def _test_dense_small_sigmoid_kernel(queue, gamma, coef0, dtype): assert_allclose(result, expected, rtol=tol) -@pass_if_not_implemented_for_gpu(reason="sigmoid kernel is not implemented") +@pass_if_not_implemented_for_gpu(reason="Sigmoid kernel is not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("gamma", [0.1, 1.0, 2.4]) @pytest.mark.parametrize("coef0", [0.0, 1.0, 5.5]) diff --git a/onedal/spmd/__init__.py b/onedal/spmd/__init__.py index 2c60cc2353..cbe5944f67 100644 --- a/onedal/spmd/__init__.py +++ b/onedal/spmd/__init__.py @@ -14,6 +14,16 @@ # limitations under the License. # ============================================================================== +from . import ( + basic_statistics, + cluster, + covariance, + decomposition, + ensemble, + linear_model, + neighbors, +) + __all__ = [ "basic_statistics", "cluster", diff --git a/onedal/spmd/_base.py b/onedal/spmd/_base.py deleted file mode 100644 index 52307ddb34..0000000000 --- a/onedal/spmd/_base.py +++ /dev/null @@ -1,30 +0,0 @@ -# ============================================================================== -# Copyright 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from abc import ABC - -from onedal import _spmd_backend - -from ..common._base import _get_backend -from ..common._spmd_policy import _get_spmd_policy - - -class BaseEstimatorSPMD(ABC): - def _get_backend(self, module, submodule=None, method=None, *args, **kwargs): - return _get_backend(_spmd_backend, module, submodule, method, *args, **kwargs) - - def _get_policy(self, queue, *data): - return _get_spmd_policy(queue) diff --git a/onedal/spmd/basic_statistics/basic_statistics.py b/onedal/spmd/basic_statistics/basic_statistics.py index 8253aa6628..72df4d778e 100644 --- a/onedal/spmd/basic_statistics/basic_statistics.py +++ b/onedal/spmd/basic_statistics/basic_statistics.py @@ -14,17 +14,15 @@ # limitations under the License. 
# ============================================================================== -from onedal.basic_statistics import BasicStatistics as BasicStatistics_Batch - from ..._device_offload import support_input_format -from .._base import BaseEstimatorSPMD +from ...basic_statistics import BasicStatistics as BasicStatistics_Batch +from ...common._backend import bind_spmd_backend -class BasicStatistics(BaseEstimatorSPMD, BasicStatistics_Batch): - @support_input_format() - def compute(self, data, weights=None, queue=None): - return super().compute(data, weights=weights, queue=queue) +class BasicStatistics(BasicStatistics_Batch): + @bind_spmd_backend("basic_statistics") + def compute(self, data, weights=None): ... - @support_input_format() + @support_input_format def fit(self, data, sample_weight=None, queue=None): - return super().fit(data, sample_weight=sample_weight, queue=queue) + return super().fit(data, sample_weight, queue=queue) diff --git a/onedal/spmd/basic_statistics/incremental_basic_statistics.py b/onedal/spmd/basic_statistics/incremental_basic_statistics.py index fe4e57b199..6eaf3d0b85 100644 --- a/onedal/spmd/basic_statistics/incremental_basic_statistics.py +++ b/onedal/spmd/basic_statistics/incremental_basic_statistics.py @@ -14,58 +14,16 @@ # limitations under the License. # ============================================================================== -from daal4py.sklearn._utils import get_dtype +from onedal.common._backend import bind_spmd_backend from ...basic_statistics import ( IncrementalBasicStatistics as base_IncrementalBasicStatistics, ) -from ...datatypes import to_table -from .._base import BaseEstimatorSPMD -class IncrementalBasicStatistics(BaseEstimatorSPMD, base_IncrementalBasicStatistics): - def _reset(self): - self._need_to_finalize = False - self._partial_result = super(base_IncrementalBasicStatistics, self)._get_backend( - "basic_statistics", None, "partial_compute_result" - ) +class IncrementalBasicStatistics(base_IncrementalBasicStatistics): + @bind_spmd_backend("basic_statistics") + def compute(self, *args, **kwargs): ... - def partial_fit(self, X, weights=None, queue=None): - """ - Computes partial data for basic statistics - from data batch X and saves it to `_partial_result`. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - Training data batch, where `n_samples` is the number of samples - in the batch, and `n_features` is the number of features. - - queue : dpctl.SyclQueue - If not None, use this queue for computations. - - Returns - ------- - self : object - Returns the instance itself. - """ - self._queue = queue - policy = super(base_IncrementalBasicStatistics, self)._get_policy(queue, X) - X_table, weights_table = to_table(X, weights, queue=queue) - - if not hasattr(self, "_onedal_params"): - self._onedal_params = self._get_onedal_params(False, dtype=X_table.dtype) - - self._partial_result = super(base_IncrementalBasicStatistics, self)._get_backend( - "basic_statistics", - None, - "partial_compute", - policy, - self._onedal_params, - self._partial_result, - X_table, - weights_table, - ) - - self._need_to_finalize = True - return self + @bind_spmd_backend("basic_statistics") + def finalize_compute(self, *args, **kwargs): ... 
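With `BaseEstimatorSPMD` removed, the SPMD estimators above are plain subclasses that only re-bind their native entry points to `_onedal_py_spmd_dpc`; all Python-side logic is inherited from the batch classes. A hedged usage sketch, assuming an MPI launcher, `mpi4py`, `dpctl`, and one GPU visible per rank (the data shard and script name are illustrative):

```python
# Run under an MPI launcher, e.g. `mpirun -n 4 python demo.py`.
import dpctl
import numpy as np
from mpi4py import MPI  # noqa: F401 -- initializes MPI for the SPMD backend

from onedal.spmd.basic_statistics import BasicStatistics

queue = dpctl.SyclQueue("gpu")                           # one device per rank
X_local = np.random.default_rng(0).random((1000, 4))     # this rank's shard
estimator = BasicStatistics().fit(X_local, queue=queue)  # collective compute
```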
diff --git a/onedal/spmd/cluster/__init__.py b/onedal/spmd/cluster/__init__.py index bb7a0b3a06..94b5385367 100644 --- a/onedal/spmd/cluster/__init__.py +++ b/onedal/spmd/cluster/__init__.py @@ -18,11 +18,9 @@ from .dbscan import DBSCAN +__all__ = ["DBSCAN"] + if daal_check_version((2023, "P", 200)): from .kmeans import KMeans - __all__ = ["DBSCAN", "KMeans"] -else: - __all__ = [ - "DBSCAN", - ] + __all__ += ["KMeans"] diff --git a/onedal/spmd/cluster/dbscan.py b/onedal/spmd/cluster/dbscan.py index 1460ed6533..0281b5b1bd 100644 --- a/onedal/spmd/cluster/dbscan.py +++ b/onedal/spmd/cluster/dbscan.py @@ -14,10 +14,10 @@ # limitations under the License. # ============================================================================== -from onedal.cluster import DBSCAN as DBSCAN_Batch +from ...cluster import DBSCAN as DBSCAN_Batch +from ...common._backend import bind_spmd_backend -from .._base import BaseEstimatorSPMD - -class DBSCAN(BaseEstimatorSPMD, DBSCAN_Batch): - pass +class DBSCAN(DBSCAN_Batch): + @bind_spmd_backend("dbscan.clustering") + def compute(self, params, data_table, weights_table): ... diff --git a/onedal/spmd/cluster/kmeans.py b/onedal/spmd/cluster/kmeans.py index 3f552a353b..ebd0c55827 100644 --- a/onedal/spmd/cluster/kmeans.py +++ b/onedal/spmd/cluster/kmeans.py @@ -14,43 +14,43 @@ # limitations under the License. # ============================================================================== -from onedal.cluster import KMeans as KMeans_Batch -from onedal.cluster import KMeansInit as KMeansInit_Batch -from onedal.spmd.basic_statistics import BasicStatistics - from ..._device_offload import support_input_format -from .._base import BaseEstimatorSPMD +from ...cluster import KMeans as KMeans_Batch +from ...cluster import KMeansInit as KMeansInit_Batch +from ...common._backend import bind_spmd_backend +from ...spmd.basic_statistics import BasicStatistics -class KMeansInit(BaseEstimatorSPMD, KMeansInit_Batch): +class KMeansInit(KMeansInit_Batch): """ KMeansInit oneDAL implementation for SPMD iface. """ - pass + @bind_spmd_backend("kmeans_init.init", lookup_name="compute") + def backend_compute(self, params, data): ... -class KMeans(BaseEstimatorSPMD, KMeans_Batch): +class KMeans(KMeans_Batch): def _get_basic_statistics_backend(self, result_options): return BasicStatistics(result_options) def _get_kmeans_init(self, cluster_count, seed, algorithm): return KMeansInit(cluster_count=cluster_count, seed=seed, algorithm=algorithm) - @support_input_format() + @bind_spmd_backend("kmeans.clustering") + def train(self, params, X_table, centroids_table): ... + + @bind_spmd_backend("kmeans.clustering") + def infer(self, params, model, centroids_table): ... + + @support_input_format def fit(self, X, y=None, queue=None): - return super().fit(X, queue=queue) + return super().fit(X, y, queue=queue) - @support_input_format() + @support_input_format def predict(self, X, queue=None): return super().predict(X, queue=queue) - @support_input_format() + @support_input_format def fit_predict(self, X, y=None, queue=None): return super().fit_predict(X, queue=queue) - - def transform(self, X): - return super().transform(X) - - def fit_transform(self, X, queue=None): - return super().fit_transform(X, queue=queue) diff --git a/onedal/spmd/covariance/covariance.py b/onedal/spmd/covariance/covariance.py index fe746b0993..d007cb88d7 100644 --- a/onedal/spmd/covariance/covariance.py +++ b/onedal/spmd/covariance/covariance.py @@ -14,13 +14,19 @@ # limitations under the License. 
# ============================================================================== -from onedal.covariance import EmpiricalCovariance as EmpiricalCovariance_Batch - from ..._device_offload import support_input_format -from .._base import BaseEstimatorSPMD +from ...common._backend import bind_spmd_backend +from ...covariance import EmpiricalCovariance as EmpiricalCovariance_Batch + + +class EmpiricalCovariance(EmpiricalCovariance_Batch): + + @bind_spmd_backend("covariance") + def compute(self, *args, **kwargs): ... + @bind_spmd_backend("covariance") + def finalize_compute(self, params, partial_result): ... -class EmpiricalCovariance(BaseEstimatorSPMD, EmpiricalCovariance_Batch): - @support_input_format() + @support_input_format def fit(self, X, y=None, queue=None): return super().fit(X, queue=queue) diff --git a/onedal/spmd/covariance/incremental_covariance.py b/onedal/spmd/covariance/incremental_covariance.py index f8d25b2a08..f077b4c1bf 100644 --- a/onedal/spmd/covariance/incremental_covariance.py +++ b/onedal/spmd/covariance/incremental_covariance.py @@ -14,70 +14,12 @@ # limitations under the License. # ============================================================================== -import numpy as np - -from daal4py.sklearn._utils import get_dtype - +from ...common._backend import bind_spmd_backend from ...covariance import ( IncrementalEmpiricalCovariance as base_IncrementalEmpiricalCovariance, ) -from ...datatypes import to_table -from ...utils import _check_array -from .._base import BaseEstimatorSPMD - - -class IncrementalEmpiricalCovariance( - BaseEstimatorSPMD, base_IncrementalEmpiricalCovariance -): - def _reset(self): - self._need_to_finalize = False - self._partial_result = super( - base_IncrementalEmpiricalCovariance, self - )._get_backend("covariance", None, "partial_compute_result") - - def partial_fit(self, X, y=None, queue=None): - """ - Computes partial data for the covariance matrix - from data batch X and saves it to `_partial_result`. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - Training data batch, where `n_samples` is the number of samples - in the batch, and `n_features` is the number of features. - - y : Ignored - Not used, present for API consistency by convention. - - queue : dpctl.SyclQueue - If not None, use this queue for computations. - - Returns - ------- - self : object - Returns the instance itself. - """ - X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True) - - self._queue = queue - - policy = super(base_IncrementalEmpiricalCovariance, self)._get_policy(queue, X) - - X_table = to_table(X, queue=queue) - if not hasattr(self, "_dtype"): - self._dtype = X_table.dtype - params = self._get_onedal_params(self._dtype) - self._partial_result = super( - base_IncrementalEmpiricalCovariance, self - )._get_backend( - "covariance", - None, - "partial_compute", - policy, - params, - self._partial_result, - X_table, - ) - self._need_to_finalize = True +class IncrementalEmpiricalCovariance(base_IncrementalEmpiricalCovariance): + @bind_spmd_backend("covariance") + def finalize_compute(self, params, partial_result): ... diff --git a/onedal/spmd/decomposition/incremental_pca.py b/onedal/spmd/decomposition/incremental_pca.py index 76c3821d52..bb7c03930b 100644 --- a/onedal/spmd/decomposition/incremental_pca.py +++ b/onedal/spmd/decomposition/incremental_pca.py @@ -14,15 +14,12 @@ # limitations under the License. 
# ============================================================================== -from daal4py.sklearn._utils import get_dtype +from onedal.common._backend import bind_spmd_backend -from ...datatypes import from_table, to_table from ...decomposition import IncrementalPCA as base_IncrementalPCA -from ...utils import _check_array -from .._base import BaseEstimatorSPMD -class IncrementalPCA(BaseEstimatorSPMD, base_IncrementalPCA): +class IncrementalPCA(base_IncrementalPCA): """ Distributed incremental estimator for PCA based on oneDAL implementation. Allows for distributed PCA computation if data is split into batches. @@ -30,95 +27,5 @@ class IncrementalPCA(BaseEstimatorSPMD, base_IncrementalPCA): API is the same as for `onedal.decomposition.IncrementalPCA` """ - def _reset(self): - self._need_to_finalize = False - self._partial_result = super(base_IncrementalPCA, self)._get_backend( - "decomposition", "dim_reduction", "partial_train_result" - ) - if hasattr(self, "components_"): - del self.components_ - - def partial_fit(self, X, y=None, queue=None): - """Incremental fit with X. All of X is processed as a single batch. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - Training data, where `n_samples` is the number of samples and - `n_features` is the number of features. - - y : Ignored - Not used, present for API consistency by convention. - - Returns - ------- - self : object - Returns the instance itself. - """ - X = _check_array(X) - n_samples, n_features = X.shape - - first_pass = not hasattr(self, "components_") - if first_pass: - self.components_ = None - self.n_samples_seen_ = n_samples - self.n_features_in_ = n_features - else: - self.n_samples_seen_ += n_samples - - if self.n_components is None: - if self.components_ is None: - self.n_components_ = min(n_samples, n_features) - else: - self.n_components_ = self.components_.shape[0] - else: - self.n_components_ = self.n_components - - self._queue = queue - - policy = super(base_IncrementalPCA, self)._get_policy(queue, X) - X_table = to_table(X, queue=queue) - - if not hasattr(self, "_dtype"): - self._dtype = X_table.dtype - self._params = self._get_onedal_params(X_table) - - self._partial_result = super(base_IncrementalPCA, self)._get_backend( - "decomposition", - "dim_reduction", - "partial_train", - policy, - self._params, - self._partial_result, - X_table, - ) - self._need_to_finalize = True - return self - - def _create_model(self): - m = super(base_IncrementalPCA, self)._get_backend( - "decomposition", "dim_reduction", "model" - ) - m.eigenvectors = to_table(self.components_) - m.means = to_table(self.mean_) - if self.whiten: - m.eigenvalues = to_table(self.explained_variance_) - self._onedal_model = m - return m - - def predict(self, X, queue=None): - policy = super(base_IncrementalPCA, self)._get_policy(queue, X) - model = self._create_model() - X = to_table(X, queue=queue) - params = self._get_onedal_params(X, stage="predict") - - result = super(base_IncrementalPCA, self)._get_backend( - "decomposition", - "dim_reduction", - "infer", - policy, - params, - model, - X, - ) - return from_table(result.transformed_data) + @bind_spmd_backend("decomposition.dim_reduction") + def finalize_train(self, params, partial_result, queue=None): ... diff --git a/onedal/spmd/decomposition/pca.py b/onedal/spmd/decomposition/pca.py index 55f242f782..d1442af0cc 100644 --- a/onedal/spmd/decomposition/pca.py +++ b/onedal/spmd/decomposition/pca.py @@ -14,13 +14,19 @@ # limitations under the License. 
# ============================================================================== -from onedal.decomposition.pca import PCA as PCABatch - from ..._device_offload import support_input_format -from .._base import BaseEstimatorSPMD +from ...common._backend import bind_spmd_backend +from ...decomposition.pca import PCA as PCABatch + + +class PCA(PCABatch): + + @bind_spmd_backend("decomposition.dim_reduction") + def train(self, params, X, queue=None): ... + @bind_spmd_backend("decomposition.dim_reduction") + def finalize_train(self, *args, **kwargs): ... -class PCA(BaseEstimatorSPMD, PCABatch): - @support_input_format() + @support_input_format def fit(self, X, y=None, queue=None): return super().fit(X, queue=queue) diff --git a/onedal/spmd/ensemble/__init__.py b/onedal/spmd/ensemble/__init__.py index 9068c7b255..caa541f9d5 100644 --- a/onedal/spmd/ensemble/__init__.py +++ b/onedal/spmd/ensemble/__init__.py @@ -14,6 +14,6 @@ # limitations under the License. # ============================================================================== -from .forest import RandomForestClassifier, RandomForestRegressor +from ...ensemble import RandomForestClassifier, RandomForestRegressor __all__ = ["RandomForestClassifier", "RandomForestRegressor"] diff --git a/onedal/spmd/ensemble/forest.py b/onedal/spmd/ensemble/forest.py deleted file mode 100644 index 90a3f924db..0000000000 --- a/onedal/spmd/ensemble/forest.py +++ /dev/null @@ -1,28 +0,0 @@ -# ============================================================================== -# Copyright 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from onedal.ensemble import RandomForestClassifier as RandomForestClassifier_Batch -from onedal.ensemble import RandomForestRegressor as RandomForestRegressor_Batch - -from .._base import BaseEstimatorSPMD - - -class RandomForestClassifier(BaseEstimatorSPMD, RandomForestClassifier_Batch): - pass - - -class RandomForestRegressor(BaseEstimatorSPMD, RandomForestRegressor_Batch): - pass diff --git a/onedal/spmd/linear_model/incremental_linear_model.py b/onedal/spmd/linear_model/incremental_linear_model.py index d3846bc82a..bfdc00c4b7 100644 --- a/onedal/spmd/linear_model/incremental_linear_model.py +++ b/onedal/spmd/linear_model/incremental_linear_model.py @@ -14,84 +14,20 @@ # limitations under the License. 
# ============================================================================== -import numpy as np -from daal4py.sklearn._utils import get_dtype +from onedal.common._backend import bind_spmd_backend -from ...common.hyperparameters import get_hyperparameters -from ...datatypes import to_table from ...linear_model import ( IncrementalLinearRegression as base_IncrementalLinearRegression, ) -from ...utils import _check_X_y, _num_features -from .._base import BaseEstimatorSPMD -class IncrementalLinearRegression(BaseEstimatorSPMD, base_IncrementalLinearRegression): +class IncrementalLinearRegression(base_IncrementalLinearRegression): """ Distributed incremental Linear Regression oneDAL implementation. API is the same as for `onedal.linear_model.IncrementalLinearRegression`. """ - def _reset(self): - self._partial_result = super(base_IncrementalLinearRegression, self)._get_backend( - "linear_model", "regression", "partial_train_result" - ) - - def partial_fit(self, X, y, queue=None): - """ - Computes partial data for linear regression - from data batch X and saves it to `_partial_result`. - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - Training data batch, where `n_samples` is the number of samples - in the batch, and `n_features` is the number of features. - - y: array-like of shape (n_samples,) or (n_samples, n_targets) in - case of multiple targets - Responses for training data. - - queue : dpctl.SyclQueue - If not None, use this queue for computations. - Returns - ------- - self : object - Returns the instance itself. - """ - module = super(base_IncrementalLinearRegression, self)._get_backend( - "linear_model", "regression" - ) - - self._queue = queue - policy = super(base_IncrementalLinearRegression, self)._get_policy(queue, X) - - X, y = _check_X_y( - X, y, dtype=[np.float64, np.float32], accept_2d_y=True, force_all_finite=False - ) - - X_table, y_table = to_table(X, y, queue=queue) - - if not hasattr(self, "_dtype"): - self._dtype = X_table.dtype - self._params = self._get_onedal_params(self._dtype) - - y = np.asarray(y, dtype=self._dtype) - - self.n_features_in_ = _num_features(X, fallback_1d=True) - - hparams = get_hyperparameters("linear_regression", "train") - if hparams is not None and not hparams.is_default: - self._partial_result = module.partial_train( - policy, - self._params, - hparams.backend, - self._partial_result, - X_table, - y_table, - ) - else: - self._partial_result = module.partial_train( - policy, self._params, self._partial_result, X_table, y_table - ) + @bind_spmd_backend("linear_model.regression") + def finalize_train(self, *args, **kwargs): ... diff --git a/onedal/spmd/linear_model/linear_model.py b/onedal/spmd/linear_model/linear_model.py index 11d9cbe0e8..cbe3af8dc0 100644 --- a/onedal/spmd/linear_model/linear_model.py +++ b/onedal/spmd/linear_model/linear_model.py @@ -14,17 +14,26 @@ # limitations under the License. # ============================================================================== -from onedal.linear_model import LinearRegression as LinearRegression_Batch - from ..._device_offload import support_input_format -from .._base import BaseEstimatorSPMD +from ...common._backend import bind_spmd_backend +from ...linear_model import LinearRegression as LinearRegression_Batch + + +class LinearRegression(LinearRegression_Batch): + + @bind_spmd_backend("linear_model.regression") + def train(self, *args, **kwargs): ... + + @bind_spmd_backend("linear_model.regression") + def finalize_train(self, *args, **kwargs): ... 
+ @bind_spmd_backend("linear_model.regression") + def infer(self, params, model, X): ... -class LinearRegression(BaseEstimatorSPMD, LinearRegression_Batch): - @support_input_format() + @support_input_format def fit(self, X, y, queue=None): return super().fit(X, y, queue=queue) - @support_input_format() + @support_input_format def predict(self, X, queue=None): return super().predict(X, queue=queue) diff --git a/onedal/spmd/linear_model/logistic_regression.py b/onedal/spmd/linear_model/logistic_regression.py index 38529eaef7..5dfed76b59 100644 --- a/onedal/spmd/linear_model/logistic_regression.py +++ b/onedal/spmd/linear_model/logistic_regression.py @@ -14,25 +14,31 @@ # limitations under the License. # ============================================================================== -from onedal.linear_model import LogisticRegression as LogisticRegression_Batch - from ..._device_offload import support_input_format -from .._base import BaseEstimatorSPMD +from ...common._backend import bind_spmd_backend +from ...linear_model import LogisticRegression as LogisticRegression_Batch + + +class LogisticRegression(LogisticRegression_Batch): + + @bind_spmd_backend("logistic_regression.classification") + def train(self, params, X, y): ... + @bind_spmd_backend("logistic_regression.classification") + def infer(self, params, X, model): ... -class LogisticRegression(BaseEstimatorSPMD, LogisticRegression_Batch): - @support_input_format() + @support_input_format def fit(self, X, y, queue=None): return super().fit(X, y, queue=queue) - @support_input_format() + @support_input_format def predict(self, X, queue=None): return super().predict(X, queue=queue) - @support_input_format() + @support_input_format def predict_proba(self, X, queue=None): return super().predict_proba(X, queue=queue) - @support_input_format() + @support_input_format def predict_log_proba(self, X, queue=None): return super().predict_log_proba(X, queue=queue) diff --git a/onedal/spmd/neighbors/__init__.py b/onedal/spmd/neighbors/__init__.py index 8036511d9f..1aa6247605 100644 --- a/onedal/spmd/neighbors/__init__.py +++ b/onedal/spmd/neighbors/__init__.py @@ -14,6 +14,6 @@ # limitations under the License. # ============================================================================== -from .neighbors import KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors +from .neighbors import KNeighborsClassifier, KNeighborsRegressor -__all__ = ["KNeighborsClassifier", "KNeighborsRegressor", "NearestNeighbors"] +__all__ = ["KNeighborsClassifier", "KNeighborsRegressor"] diff --git a/onedal/spmd/neighbors/neighbors.py b/onedal/spmd/neighbors/neighbors.py index 87004e1a77..b9f5f98d18 100644 --- a/onedal/spmd/neighbors/neighbors.py +++ b/onedal/spmd/neighbors/neighbors.py @@ -14,62 +14,73 @@ # limitations under the License. 
 # ==============================================================================
 
-from onedal.neighbors import KNeighborsClassifier as KNeighborsClassifier_Batch
-from onedal.neighbors import KNeighborsRegressor as KNeighborsRegressor_Batch
+from ..._device_offload import support_input_format, supports_queue
+from ...common._backend import bind_spmd_backend
+from ...neighbors import KNeighborsClassifier as KNeighborsClassifier_Batch
+from ...neighbors import KNeighborsRegressor as KNeighborsRegressor_Batch
 
-from ..._device_offload import support_input_format
-from .._base import BaseEstimatorSPMD
 
+class KNeighborsClassifier(KNeighborsClassifier_Batch):
 
-class KNeighborsClassifier(BaseEstimatorSPMD, KNeighborsClassifier_Batch):
-    @support_input_format()
+    @bind_spmd_backend("neighbors.classification")
+    def train(self, *args, **kwargs): ...
+
+    @bind_spmd_backend("neighbors.classification")
+    def infer(self, *args, **kwargs): ...
+
+    @support_input_format
     def fit(self, X, y, queue=None):
         return super().fit(X, y, queue=queue)
 
-    @support_input_format()
+    @support_input_format
     def predict(self, X, queue=None):
         return super().predict(X, queue=queue)
 
-    @support_input_format()
+    @support_input_format
     def predict_proba(self, X, queue=None):
         raise NotImplementedError("predict_proba not supported in distributed mode.")
 
-    @support_input_format()
+    @support_input_format
     def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
         return super().kneighbors(X, n_neighbors, return_distance, queue=queue)
 
 
-class KNeighborsRegressor(BaseEstimatorSPMD, KNeighborsRegressor_Batch):
-    @support_input_format()
+class KNeighborsRegressor(KNeighborsRegressor_Batch):
+
+    @bind_spmd_backend("neighbors.search", lookup_name="train")
+    def train_search(self, *args, **kwargs): ...
+
+    @bind_spmd_backend("neighbors.search", lookup_name="infer")
+    def infer_search(self, *args, **kwargs): ...
+
+    @bind_spmd_backend("neighbors.regression")
+    def train(self, *args, **kwargs): ...
+
+    @bind_spmd_backend("neighbors.regression")
+    def infer(self, *args, **kwargs): ...
+
+    @support_input_format
+    @supports_queue
     def fit(self, X, y, queue=None):
         if queue is not None and queue.sycl_device.is_gpu:
-            return super()._fit(X, y, queue=queue)
+            return self._fit(X, y)
         else:
             raise ValueError(
                 "SPMD version of kNN is not implemented for "
-                "CPU. Consider running on it on GPU."
+                "CPU. Consider running it on GPU."
) - @support_input_format() + @support_input_format def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None): return super().kneighbors(X, n_neighbors, return_distance, queue=queue) - @support_input_format() + @support_input_format + @supports_queue def predict(self, X, queue=None): - return self._predict_gpu(X, queue=queue) + return self._predict_gpu(X) def _get_onedal_params(self, X, y=None): params = super()._get_onedal_params(X, y) if "responses" not in params["result_option"]: params["result_option"] += "|responses" return params - - -class NearestNeighbors(BaseEstimatorSPMD): - @support_input_format() - def fit(self, X, y, queue=None): - return super().fit(X, y, queue=queue) - - @support_input_format() - def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None): - return super().kneighbors(X, n_neighbors, return_distance, queue=queue) diff --git a/onedal/svm/svm.py b/onedal/svm/svm.py index a4c5046d0f..b0adfe14d6 100644 --- a/onedal/svm/svm.py +++ b/onedal/svm/svm.py @@ -20,13 +20,13 @@ import numpy as np from scipy import sparse as sp -from onedal import _backend +from onedal._device_offload import SyclQueueManager, supports_queue +from onedal.common._backend import bind_default_backend from ..common._estimator_checks import _check_is_fitted from ..common._mixin import ClassifierMixin, RegressorMixin -from ..common._policy import _get_policy from ..datatypes import from_table, to_table -from ..utils import ( +from ..utils.validation import ( _check_array, _check_n_features, _check_X_y, @@ -84,6 +84,16 @@ def __init__( self.algorithm = algorithm self.svm_type = svm_type + @abstractmethod + def train(self, *args, **kwargs): ... + + @abstractmethod + def infer(self, *args, **kwargs): ... + + def _is_classification(self): + """helper function to determine if infer method was loaded from a classification module""" + return hasattr(self.infer, "name") and "classification" in self.infer.name + def _validate_targets(self, y, dtype): self.class_weight_ = None self.classes_ = None @@ -114,7 +124,7 @@ def _get_onedal_params(self, data): "cache_size": self.cache_size, } - def _fit(self, X, y, sample_weight, module, queue): + def _fit(self, X, y, sample_weight): if hasattr(self, "decision_function_shape"): if self.decision_function_shape not in ("ovr", "ovo", None): raise ValueError( @@ -166,17 +176,15 @@ def _fit(self, X, y, sample_weight, module, queue): _gamma = 1.0 / X.shape[1] else: raise ValueError( - "When 'gamma' is a string, it should be either 'scale' or " - "'auto'. Got '{}' instead.".format(self.gamma) + f"When 'gamma' is a string, it should be either 'scale' or 'auto'. Got '{self.gamma}' instead." 
                 )
         else:
             _gamma = self.gamma
 
         self._scale_, self._sigma_ = _gamma, np.sqrt(0.5 / _gamma)
 
-        policy = _get_policy(queue, *data)
-        data = to_table(*data, queue=queue)
+        data = to_table(*data, queue=SyclQueueManager.get_global_queue())
         params = self._get_onedal_params(data[0])
-        result = module.train(policy, params, *data)
+        result = self.train(params, *data)
 
         if self._sparse:
             self.dual_coef_ = sp.csr_matrix(from_table(result.coeffs).T)
@@ -200,8 +208,8 @@
         self._onedal_model = result.model
         return self
 
-    def _create_model(self, module):
-        m = module.model()
+    def _create_model(self):
+        m = self.model()
 
         m.support_vectors = to_table(self.support_vectors_)
         m.coeffs = to_table(self.dual_coef_.T)
@@ -211,14 +219,14 @@
             m.first_class_response, m.second_class_response = 0, 1
         return m
 
-    def _predict(self, X, module, queue):
+    def _predict(self, X):
         _check_is_fitted(self)
         if self.break_ties and self.decision_function_shape == "ovo":
             raise ValueError(
                 "break_ties must be False when " "decision_function_shape is 'ovo'"
             )
 
-        if module in [_backend.svm.classification, _backend.svm.nu_classification]:
+        if self._is_classification():
             sv = self.support_vectors_
             if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]:
                 raise ValueError(
@@ -252,15 +260,14 @@
                 % type(self).__name__
             )
 
-        policy = _get_policy(queue, X)
-        X = to_table(X, queue=queue)
+        X = to_table(X, queue=SyclQueueManager.get_global_queue())
         params = self._get_onedal_params(X)
 
         if hasattr(self, "_onedal_model"):
             model = self._onedal_model
         else:
-            model = self._create_model(module)
-        result = module.infer(policy, params, model, X)
+            model = self._create_model()
+        result = self.infer(params, model, X)
         y = from_table(result.responses)
         return y
 
@@ -283,7 +290,7 @@ def _ovr_decision_function(self, predictions, confidences, n_classes):
         )
         return votes + transformed_confidences
 
-    def _decision_function(self, X, module, queue):
+    def _decision_function(self, X):
         _check_is_fitted(self)
         X = _check_array(
             X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse="csr"
@@ -301,7 +308,7 @@
             % type(self).__name__
         )
 
-        if module in [_backend.svm.classification, _backend.svm.nu_classification]:
+        if self._is_classification():
             sv = self.support_vectors_
             if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]:
                 raise ValueError(
@@ -309,15 +316,14 @@
                     f"of {self.__class__.__name__} was altered"
                 )
 
-        policy = _get_policy(queue, X)
-        X = to_table(X, queue=queue)
+        X = to_table(X, queue=SyclQueueManager.get_global_queue())
         params = self._get_onedal_params(X)
 
         if hasattr(self, "_onedal_model"):
             model = self._onedal_model
         else:
-            model = self._create_model(module)
-        result = module.infer(policy, params, model, X)
+            model = self._create_model()
+        result = self.infer(params, model, X)
 
         decision_function = from_table(result.decision_function)
         if len(self.classes_) == 2:
@@ -372,11 +378,22 @@
         )
         self.svm_type = SVMtype.epsilon_svr
 
+    @bind_default_backend("svm.regression")
+    def train(self, *args, **kwargs): ...
+
+    @bind_default_backend("svm.regression")
+    def infer(self, *args, **kwargs): ...
+
+    @bind_default_backend("svm.regression")
+    def model(self): ...
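The `@supports_queue` decorator on the public `fit`/`predict` entry points below is what lets the internal `_fit`/`_predict`/`_decision_function` helpers drop their `queue` parameters: the queue is installed globally for the duration of the call and read back via `SyclQueueManager.get_global_queue()`. A minimal sketch of that contract, built only on the `manage_global_queue` context manager introduced in this patch (the real decorator also reconciles the queue with the devices the data lives on):

```python
# Hedged sketch of the supports_queue contract; the authoritative version
# lives in onedal._device_offload.
from functools import wraps

from onedal._device_offload import SyclQueueManager


def supports_queue_sketch(func):
    @wraps(func)
    def wrapper(self, *args, queue=None, **kwargs):
        # publish `queue` globally while the call runs, restore state after;
        # callees read SyclQueueManager.get_global_queue() instead of a parameter
        with SyclQueueManager.manage_global_queue(queue, *args):
            return func(self, *args, queue=queue, **kwargs)

    return wrapper
```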
+ + @supports_queue def fit(self, X, y, sample_weight=None, queue=None): - return super()._fit(X, y, sample_weight, _backend.svm.regression, queue) + return self._fit(X, y, sample_weight) + @supports_queue def predict(self, X, queue=None): - y = super()._predict(X, _backend.svm.regression, queue) + y = self._predict(X) return y.ravel() @@ -424,23 +441,35 @@ def __init__( ) self.svm_type = SVMtype.c_svc + @bind_default_backend("svm.classification") + def train(self, *args, **kwargs): ... + + @bind_default_backend("svm.classification") + def infer(self, *args, **kwargs): ... + + @bind_default_backend("svm.classification") + def model(self): ... + def _validate_targets(self, y, dtype): y, self.class_weight_, self.classes_ = _validate_targets( y, self.class_weight, dtype ) return y + @supports_queue def fit(self, X, y, sample_weight=None, queue=None): - return super()._fit(X, y, sample_weight, _backend.svm.classification, queue) + return self._fit(X, y, sample_weight) + @supports_queue def predict(self, X, queue=None): - y = super()._predict(X, _backend.svm.classification, queue) + y = self._predict(X) if len(self.classes_) == 2: y = y.ravel() return self.classes_.take(np.asarray(y, dtype=np.intp)).ravel() + @supports_queue def decision_function(self, X, queue=None): - return super()._decision_function(X, _backend.svm.classification, queue) + return self._decision_function(X) class NuSVR(RegressorMixin, BaseSVM): @@ -485,12 +514,22 @@ def __init__( ) self.svm_type = SVMtype.nu_svr + @bind_default_backend("svm.nu_regression") + def train(self, *args, **kwargs): ... + + @bind_default_backend("svm.nu_regression") + def infer(self, *args, **kwargs): ... + + @bind_default_backend("svm.nu_regression") + def model(self): ... + + @supports_queue def fit(self, X, y, sample_weight=None, queue=None): - return super()._fit(X, y, sample_weight, _backend.svm.nu_regression, queue) + return self._fit(X, y, sample_weight) + @supports_queue def predict(self, X, queue=None): - y = super()._predict(X, _backend.svm.nu_regression, queue) - return y.ravel() + return self._predict(X).ravel() class NuSVC(ClassifierMixin, BaseSVM): @@ -537,20 +576,32 @@ def __init__( ) self.svm_type = SVMtype.nu_svc + @bind_default_backend("svm.nu_classification") + def train(self, *args, **kwargs): ... + + @bind_default_backend("svm.nu_classification") + def infer(self, *args, **kwargs): ... + + @bind_default_backend("svm.nu_classification") + def model(self): ... 
+ def _validate_targets(self, y, dtype): y, self.class_weight_, self.classes_ = _validate_targets( y, self.class_weight, dtype ) return y + @supports_queue def fit(self, X, y, sample_weight=None, queue=None): - return super()._fit(X, y, sample_weight, _backend.svm.nu_classification, queue) + return self._fit(X, y, sample_weight) + @supports_queue def predict(self, X, queue=None): - y = super()._predict(X, _backend.svm.nu_classification, queue) + y = self._predict(X) if len(self.classes_) == 2: y = y.ravel() return self.classes_.take(np.asarray(y, dtype=np.intp)).ravel() + @supports_queue def decision_function(self, X, queue=None): - return super()._decision_function(X, _backend.svm.nu_classification, queue) + return self._decision_function(X) diff --git a/onedal/svm/tests/test_csr_svm.py b/onedal/svm/tests/test_csr_svm.py index e4a05a030e..d7da6d404c 100644 --- a/onedal/svm/tests/test_csr_svm.py +++ b/onedal/svm/tests/test_csr_svm.py @@ -74,7 +74,7 @@ def _test_simple_dataset(queue, kernel): check_svm_model_equal(queue, clf0, clf1, *dataset) -@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize( "queue", get_queues("cpu") @@ -82,8 +82,7 @@ def _test_simple_dataset(queue, kernel): pytest.param( get_queues("gpu"), marks=pytest.mark.xfail( - reason="raises UnknownError instead of RuntimeError " - "with unimplemented message" + reason="raises UnknownError instead of RuntimeError with unimplemented message" ), ) ], @@ -103,7 +102,7 @@ def _test_binary_dataset(queue, kernel): check_svm_model_equal(queue, clf0, clf1, *dataset) -@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize( "queue", get_queues("cpu") @@ -111,9 +110,11 @@ def _test_binary_dataset(queue, kernel): pytest.param( get_queues("gpu"), marks=pytest.mark.xfail( - reason="raises UnknownError for linear and rbf, " - "Unimplemented error with inconsistent error message " - "for poly and sigmoid" + reason=( + "raises UnknownError for linear and rbf, " + "Unimplemented error with inconsistent error message " + "for poly and sigmoid" + ) ), ) ], @@ -138,7 +139,7 @@ def _test_iris(queue, kernel): check_svm_model_equal(queue, clf0, clf1, *dataset, decimal=2) -@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"]) def test_iris(queue, kernel): @@ -158,7 +159,7 @@ def _test_diabetes(queue, kernel): check_svm_model_equal(queue, clf0, clf1, *dataset) -@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"]) def test_diabetes(queue, kernel): diff --git a/onedal/svm/tests/test_nusvc.py b/onedal/svm/tests/test_nusvc.py index c8bf99a9d3..29e8d2272f 100644 --- a/onedal/svm/tests/test_nusvc.py +++ b/onedal/svm/tests/test_nusvc.py @@ -44,7 +44,7 @@ def _test_libsvm_parameters(queue, array_constr, dtype): assert_array_equal(clf.predict(X, queue=queue), y) -@pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("array_constr", 
[np.array]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @@ -52,7 +52,7 @@ def test_libsvm_parameters(queue, array_constr, dtype): _test_libsvm_parameters(queue, array_constr, dtype) -@pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize("queue", get_queues()) def test_class_weight(queue): X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]) @@ -63,7 +63,7 @@ def test_class_weight(queue): assert_array_almost_equal(clf.predict(X, queue=queue), [2] * 6) -@pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize("queue", get_queues()) def test_sample_weight(queue): X = np.array([[-2, 0], [-1, -1], [0, -2], [0, 2], [1, 1], [2, 2]]) @@ -74,7 +74,7 @@ def test_sample_weight(queue): assert_array_almost_equal(clf.intercept_, [0.0]) -@pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize("queue", get_queues()) def test_decision_function(queue): X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] @@ -88,7 +88,7 @@ def test_decision_function(queue): assert_array_almost_equal(dec.ravel(), clf.decision_function(X, queue=queue)) -@pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize("queue", get_queues()) def test_iris(queue): iris = datasets.load_iris() @@ -97,7 +97,7 @@ def test_iris(queue): assert_array_equal(clf.classes_, np.sort(clf.classes_)) -@pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize("queue", get_queues()) def test_decision_function_shape(queue): X, y = make_blobs(n_samples=80, centers=5, random_state=0) @@ -114,7 +114,7 @@ def test_decision_function_shape(queue): # SVC(decision_function_shape='bad').fit(X_train, y_train) -@pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize("queue", get_queues()) def test_pickle(queue): iris = datasets.load_iris() @@ -146,7 +146,7 @@ def _test_cancer_rbf_compare_with_sklearn(queue, nu, gamma): assert abs(result - expected) < 1e-4 -@pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("gamma", ["scale", "auto"]) @pytest.mark.parametrize("nu", [0.25, 0.5]) @@ -169,7 +169,7 @@ def _test_cancer_linear_compare_with_sklearn(queue, nu): assert abs(result - expected) < 1e-3 -@pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("nu", [0.25, 0.5]) def test_cancer_linear_compare_with_sklearn(queue, nu): @@ -191,7 +191,7 @@ def _test_cancer_poly_compare_with_sklearn(queue, params): assert abs(result - expected) < 1e-4 -@pass_if_not_implemented_for_gpu(reason="nusvc is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize( "params", diff --git a/onedal/svm/tests/test_nusvr.py b/onedal/svm/tests/test_nusvr.py index 1bec991961..6bcc04e9f4 100644 --- 
a/onedal/svm/tests/test_nusvr.py +++ b/onedal/svm/tests/test_nusvr.py @@ -30,7 +30,7 @@ synth_params = {"n_samples": 500, "n_features": 100, "random_state": 42} -@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) def test_diabetes_simple(queue): diabetes = datasets.load_diabetes() @@ -39,7 +39,7 @@ def test_diabetes_simple(queue): assert clf.score(diabetes.data, diabetes.target, queue=queue) > 0.02 -@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) def test_input_format_for_diabetes(queue): diabetes = datasets.load_diabetes() @@ -67,7 +67,7 @@ def test_input_format_for_diabetes(queue): assert_allclose(res_c_contiguous_numpy, res_f_contiguous_numpy) -@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) def test_predict(queue): iris = datasets.load_iris() @@ -105,7 +105,7 @@ def _test_diabetes_compare_with_sklearn(queue, kernel): assert_allclose(clf_sklearn.dual_coef_, clf_onedal.dual_coef_, atol=1e-2) -@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"]) def test_diabetes_compare_with_sklearn(queue, kernel): @@ -129,7 +129,7 @@ def _test_synth_rbf_compare_with_sklearn(queue, C, nu, gamma): assert abs(result - expected) < 1e-3 -@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("gamma", ["scale", "auto"]) @pytest.mark.parametrize("C", [100.0, 1000.0]) @@ -155,7 +155,7 @@ def _test_synth_linear_compare_with_sklearn(queue, C, nu): assert abs(result - expected) < 1e-3 -@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("C", [0.001, 0.1]) @pytest.mark.parametrize("nu", [0.25, 0.75]) @@ -178,7 +178,7 @@ def _test_synth_poly_compare_with_sklearn(queue, params): assert abs(result - expected) < 1e-3 -@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize( "params", @@ -191,7 +191,7 @@ def test_synth_poly_compare_with_sklearn(queue, params): _test_synth_poly_compare_with_sklearn(queue, params) -@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented") +@pass_if_not_implemented_for_gpu(reason="not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) def test_pickle(queue): diabetes = datasets.load_diabetes() diff --git a/onedal/svm/tests/test_svr.py b/onedal/svm/tests/test_svr.py index a9000ff5f7..8432fb09b3 100644 --- a/onedal/svm/tests/test_svr.py +++ b/onedal/svm/tests/test_svr.py @@ -30,7 +30,7 @@ synth_params = {"n_samples": 500, "n_features": 100, "random_state": 42} -@pass_if_not_implemented_for_gpu(reason="svr is not implemented") +@pass_if_not_implemented_for_gpu(reason="Regression SVM is not implemented for GPU") 
@pytest.mark.parametrize("queue", get_queues()) def test_run_to_run_fit(queue): diabetes = datasets.load_diabetes() @@ -45,7 +45,7 @@ def test_run_to_run_fit(queue): assert_allclose(clf_first.dual_coef_, clf.dual_coef_) -@pass_if_not_implemented_for_gpu(reason="svr is not implemented") +@pass_if_not_implemented_for_gpu(reason="Regression SVM is not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) def test_diabetes_simple(queue): diabetes = datasets.load_diabetes() @@ -54,7 +54,7 @@ def test_diabetes_simple(queue): assert clf.score(diabetes.data, diabetes.target, queue=queue) > 0.02 -@pass_if_not_implemented_for_gpu(reason="svr is not implemented") +@pass_if_not_implemented_for_gpu(reason="Regression SVM is not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) def test_input_format_for_diabetes(queue): diabetes = datasets.load_diabetes() @@ -82,7 +82,7 @@ def test_input_format_for_diabetes(queue): assert_allclose(res_c_contiguous_numpy, res_f_contiguous_numpy) -@pass_if_not_implemented_for_gpu(reason="svr is not implemented") +@pass_if_not_implemented_for_gpu(reason="Regression SVM is not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) def test_predict(queue): iris = datasets.load_iris() @@ -120,7 +120,7 @@ def _test_diabetes_compare_with_sklearn(queue, kernel): assert_allclose(clf_sklearn.dual_coef_, clf_onedal.dual_coef_, atol=1e-1) -@pass_if_not_implemented_for_gpu(reason="svr is not implemented") +@pass_if_not_implemented_for_gpu(reason="Regression SVM is not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"]) def test_diabetes_compare_with_sklearn(queue, kernel): @@ -143,7 +143,7 @@ def _test_synth_rbf_compare_with_sklearn(queue, C, gamma): assert result > expected - 1e-5 -@pass_if_not_implemented_for_gpu(reason="svr is not implemented") +@pass_if_not_implemented_for_gpu(reason="Regression SVM is not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("gamma", ["scale", "auto"]) @pytest.mark.parametrize("C", [100.0, 1000.0]) @@ -167,7 +167,7 @@ def _test_synth_linear_compare_with_sklearn(queue, C): assert result > expected - 1e-3 -@pass_if_not_implemented_for_gpu(reason="svr is not implemented") +@pass_if_not_implemented_for_gpu(reason="Regression SVM is not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("C", [0.001, 0.1]) def test_synth_linear_compare_with_sklearn(queue, C): @@ -188,7 +188,7 @@ def _test_synth_poly_compare_with_sklearn(queue, params): assert result > expected - 1e-5 -@pass_if_not_implemented_for_gpu(reason="svr is not implemented") +@pass_if_not_implemented_for_gpu(reason="Regression SVM is not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize( "params", @@ -201,7 +201,7 @@ def test_synth_poly_compare_with_sklearn(queue, params): _test_synth_poly_compare_with_sklearn(queue, params) -@pass_if_not_implemented_for_gpu(reason="svr is not implemented") +@pass_if_not_implemented_for_gpu(reason="Regression SVM is not implemented for GPU") @pytest.mark.parametrize("queue", get_queues()) def test_sided_sample_weight(queue): clf = SVR(C=1e-2, kernel="linear") @@ -225,7 +225,7 @@ def test_sided_sample_weight(queue): assert y_pred == pytest.approx(1.5) -@pass_if_not_implemented_for_gpu(reason="svr is not implemented") +@pass_if_not_implemented_for_gpu(reason="Regression SVM is not implemented 
for GPU") @pytest.mark.parametrize("queue", get_queues()) def test_pickle(queue): diabetes = datasets.load_diabetes() diff --git a/onedal/tests/utils/_device_selection.py b/onedal/tests/utils/_device_selection.py index f1b29ab3b9..bdbe27d4eb 100644 --- a/onedal/tests/utils/_device_selection.py +++ b/onedal/tests/utils/_device_selection.py @@ -73,18 +73,6 @@ def is_dpctl_device_available(targets): return False -def device_type_to_str(queue): - if queue is None: - return "cpu" - - if dpctl_available: - if queue.sycl_device.is_cpu: - return "cpu" - if queue.sycl_device.is_gpu: - return "gpu" - return "unknown" - - def pass_if_not_implemented_for_gpu(reason=""): assert reason @@ -92,7 +80,7 @@ def decorator(test): @functools.wraps(test) def wrapper(queue, *args, **kwargs): if queue is not None and queue.sycl_device.is_gpu: - with pytest.raises(RuntimeError, match="is not implemented for GPU"): + with pytest.raises(RuntimeError, match=reason): test(queue, *args, **kwargs) else: test(queue, *args, **kwargs) diff --git a/onedal/utils/__init__.py b/onedal/utils/__init__.py deleted file mode 100644 index 0a1b05fbc2..0000000000 --- a/onedal/utils/__init__.py +++ /dev/null @@ -1,49 +0,0 @@ -# ============================================================================== -# Copyright 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/onedal/utils/validation.py b/onedal/utils/validation.py
index 8d3fdaa2e4..9049767d4a 100644
--- a/onedal/utils/validation.py
+++ b/onedal/utils/validation.py
@@ -21,6 +21,9 @@
 import numpy as np
 from scipy import sparse as sp

+from onedal._device_offload import supports_queue
+from onedal.common._backend import BackendFunction
+
 if np.lib.NumpyVersion(np.__version__) >= np.lib.NumpyVersion("2.0.0a0"):
     # numpy_version >= 2.0
     from numpy.exceptions import VisibleDeprecationWarning
@@ -34,8 +37,7 @@
 from daal4py.sklearn.utils.validation import (
     _assert_all_finite as _daal4py_assert_all_finite,
 )
-from onedal import _backend
-from onedal.common._policy import _get_policy
+from onedal import _default_backend as backend
 from onedal.datatypes import to_table
@@ -437,25 +439,29 @@ def _is_csr(x):

 def _assert_all_finite(X, allow_nan=False, input_name=""):
-    policy = _get_policy(None, X)
+    backend_method = BackendFunction(
+        backend.finiteness_checker.compute.compute, backend, "compute", no_policy=False
+    )
     X_t = to_table(X)
     params = {
         "fptype": X_t.dtype,
         "method": "dense",
         "allow_nan": allow_nan,
     }
-    if not _backend.finiteness_checker.compute.compute(policy, params, X_t).finite:
+    if not backend_method(params, X_t).finite:
         type_err = "infinity" if allow_nan else "NaN, infinity"
         padded_input_name = input_name + " " if input_name else ""
         msg_err = f"Input {padded_input_name}contains {type_err}."
         raise ValueError(msg_err)

+@supports_queue
 def assert_all_finite(
     X,
     *,
     allow_nan=False,
     input_name="",
+    queue=None,
 ):
     _assert_all_finite(
         X.data if sp.issparse(X) else X,
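Two things change in validation.py: the raw policy plumbing is replaced by a BackendFunction wrapper around the oneDAL finiteness_checker, and the public assert_all_finite becomes queue-aware via supports_queue. A hedged usage sketch (the SyclQueue construction assumes dpctl and a GPU device are available; device selection strings are illustrative):

import numpy as np
from onedal.utils.validation import assert_all_finite

X = np.random.rand(1000, 10)
assert_all_finite(X)  # no queue: runs via the default (host) path

# with the new keyword, the check can be pinned to a device queue
from dpctl import SyclQueue
assert_all_finite(X, queue=SyclQueue("gpu"))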
diff --git a/pyproject.toml b/pyproject.toml
index 3255e3fa58..5f3df61b24 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,8 +18,13 @@
 [tool.black]
 line-length = 90
 target-version = ['py39', 'py310', 'py311', 'py312']
-extend-ignore = 'E203'

 [tool.isort]
 profile = "black"
 line_length = 90
+
+[tool.pytest.ini_options]
+log_cli = false
+log_cli_level = "INFO"
+log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
+log_cli_date_format = "%Y-%m-%d %H:%M:%S"
diff --git a/sklearnex/__init__.py b/sklearnex/__init__.py
index 677e681f8d..50b325ed08 100755
--- a/sklearnex/__init__.py
+++ b/sklearnex/__init__.py
@@ -54,10 +54,10 @@
 ]

 onedal_iface_flag = os.environ.get("OFF_ONEDAL_IFACE", "0")
 if onedal_iface_flag == "0":
-    from onedal import _is_spmd_backend
+    from onedal import _spmd_backend
     from onedal.common.hyperparameters import get_hyperparameters

-    if _is_spmd_backend:
+    if _spmd_backend is not None:
         __all__.append("spmd")
diff --git a/sklearnex/_device_offload.py b/sklearnex/_device_offload.py
index 7e299f07e0..094542cea7 100644
--- a/sklearnex/_device_offload.py
+++ b/sklearnex/_device_offload.py
@@ -16,7 +16,7 @@

 from functools import wraps

-from onedal._device_offload import _copy_to_usm, _get_global_queue, _transfer_to_host
+from onedal._device_offload import SyclQueueManager, _copy_to_usm, _transfer_to_host
 from onedal.utils._array_api import _asarray
 from onedal.utils._dpep_helpers import dpnp_available
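The import swap above is the visible tip of the queue-handling rework: instead of fetching a free-standing _get_global_queue(), callers go through SyclQueueManager, which owns the single global queue and can derive it from input data. A minimal sketch of the API as it is used in this file (behavior per the manager introduced in onedal/_device_offload.py earlier in this diff):

from onedal._device_offload import SyclQueueManager

# read (and lazily create) the global queue; may be None for host or "auto"
queue = SyclQueueManager.get_global_queue()

# derive and verify the queue from input arrays, scoped via the context manager
with SyclQueueManager.manage_global_queue(None, X, y) as q:
    ...  # q is the queue extracted from X/y, or None for host data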
@@ -27,70 +27,72 @@
 from ._config import get_config

-def _get_backend(obj, queue, method_name, *data):
-    cpu_device = queue is None or queue.sycl_device.is_cpu
-    gpu_device = queue is not None and queue.sycl_device.is_gpu
-
-    if cpu_device:
-        patching_status = obj._onedal_cpu_supported(method_name, *data)
-        if patching_status.get_status():
-            return "onedal", queue, patching_status
-        else:
-            return "sklearn", None, patching_status
-
-    allow_fallback_to_host = get_config()["allow_fallback_to_host"]
-
-    if gpu_device:
-        patching_status = obj._onedal_gpu_supported(method_name, *data)
-        if patching_status.get_status():
-            return "onedal", queue, patching_status
-        else:
-            if allow_fallback_to_host:
-                patching_status = obj._onedal_cpu_supported(method_name, *data)
-                if patching_status.get_status():
-                    return "onedal", None, patching_status
-                else:
-                    return "sklearn", None, patching_status
+def _get_backend(obj, method_name, *data):
+    with SyclQueueManager.manage_global_queue(None, *data) as queue:
+        cpu_device = queue is None or getattr(queue.sycl_device, "is_cpu", True)
+        gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
+
+        if cpu_device:
+            patching_status = obj._onedal_cpu_supported(method_name, *data)
+            if patching_status.get_status():
+                return "onedal", patching_status
             else:
-                return "sklearn", None, patching_status
+                return "sklearn", patching_status
+
+        allow_fallback_to_host = get_config()["allow_fallback_to_host"]
+
+        if gpu_device:
+            patching_status = obj._onedal_gpu_supported(method_name, *data)
+            if patching_status.get_status():
+                return "onedal", patching_status
+            else:
+                SyclQueueManager.remove_global_queue()
+                if allow_fallback_to_host:
+                    patching_status = obj._onedal_cpu_supported(method_name, *data)
+                    if patching_status.get_status():
+                        return "onedal", patching_status
+                    else:
+                        return "sklearn", patching_status
+                else:
+                    return "sklearn", patching_status

     raise RuntimeError("Device support is not implemented")

 def dispatch(obj, method_name, branches, *args, **kwargs):
-    q = _get_global_queue()
-    has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args)
-    has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values())
-    hostkwargs = dict(zip(kwargs.keys(), hostvalues))
-
-    backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs)
-    has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs
-    if backend == "onedal":
-        # Host args only used before onedal backend call.
-        # Device will be offloaded when onedal backend will be called.
-        patching_status.write_log(queue=q, transferred_to_host=False)
-        return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
-    if backend == "sklearn":
-        if (
-            "array_api_dispatch" in get_config()
-            and get_config()["array_api_dispatch"]
-            and "array_api_support" in obj._get_tags()
-            and obj._get_tags()["array_api_support"]
-            and not has_usm_data
-        ):
-            # USM ndarrays are also excluded for the fallback Array API. Currently, DPNP.ndarray is
-            # not compliant with the Array API standard, and DPCTL usm_ndarray Array API is compliant,
-            # except for the linalg module. There is no guarantee that stock scikit-learn will
-            # work with such input data. The condition will be updated after DPNP.ndarray and
-            # DPCTL usm_ndarray enabling for conformance testing and these arrays supportance
-            # of the fallback cases.
-            # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn,
-            # then raw inputs are used for the fallback.
-            patching_status.write_log(transferred_to_host=False)
-            return branches[backend](obj, *args, **kwargs)
-        else:
-            patching_status.write_log()
-            return branches[backend](obj, *hostargs, **hostkwargs)
+    with SyclQueueManager.manage_global_queue(None, *args) as queue:
+        has_usm_data_for_args, hostargs = _transfer_to_host(*args)
+        has_usm_data_for_kwargs, hostvalues = _transfer_to_host(*kwargs.values())
+        hostkwargs = dict(zip(kwargs.keys(), hostvalues))
+
+        backend, patching_status = _get_backend(obj, method_name, *hostargs)
+        has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs
+        if backend == "onedal":
+            # Host args only used before onedal backend call.
+            # Device will be offloaded when onedal backend will be called.
+            patching_status.write_log(queue=queue, transferred_to_host=False)
+            return branches[backend](obj, *hostargs, **hostkwargs, queue=queue)
+        if backend == "sklearn":
+            if (
+                "array_api_dispatch" in get_config()
+                and get_config()["array_api_dispatch"]
+                and "array_api_support" in obj._get_tags()
+                and obj._get_tags()["array_api_support"]
+                and not has_usm_data
+            ):
+                # USM ndarrays are also excluded from the Array API fallback. Currently,
+                # DPNP.ndarray is not compliant with the Array API standard, and the DPCTL
+                # usm_ndarray Array API is compliant except for the linalg module, so there
+                # is no guarantee that stock scikit-learn will work with such input data.
+                # This condition will be updated once DPNP.ndarray and DPCTL usm_ndarray
+                # are enabled for conformance testing and the fallback cases support these
+                # array types.
+                # If `array_api_dispatch` is enabled and the Array API is supported by stock
+                # scikit-learn, raw inputs are used for the fallback.
+                patching_status.write_log(transferred_to_host=False)
+                return branches[backend](obj, *args, **kwargs)
+            else:
+                patching_status.write_log()
+                return branches[backend](obj, *hostargs, **hostkwargs)

     raise RuntimeError(
         f"Undefined backend {backend} in " f"{obj.__class__.__name__}.{method_name}"
     )
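One behavioral consequence worth noting: when the GPU patch check fails and allow_fallback_to_host is set, _get_backend now calls SyclQueueManager.remove_global_queue() before retrying the CPU path, so a stale device queue cannot leak into the fallback call. From the user side the flow looks like this (an illustrative sketch; it assumes a GPU queue and an estimator whose GPU path is unsupported, such as the SVR case in the test hunks above):

from sklearn.datasets import load_diabetes
from sklearnex import config_context
from sklearnex.svm import SVR

X_train, y_train = load_diabetes(return_X_y=True)

with config_context(target_offload="gpu:0", allow_fallback_to_host=True):
    # GPU support check fails -> global queue is dropped -> oneDAL CPU path runs
    SVR().fit(X_train, y_train)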
diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py
index da82e3bd82..4f860c9e4c 100644
--- a/sklearnex/basic_statistics/basic_statistics.py
+++ b/sklearnex/basic_statistics/basic_statistics.py
@@ -193,7 +193,7 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):
         if not hasattr(self, "_onedal_estimator"):
             self._onedal_estimator = self._onedal_basic_statistics(**onedal_params)
-        self._onedal_estimator.fit(X, sample_weight, queue)
+        self._onedal_estimator.fit(X, sample_weight, queue=queue)
         self._save_attributes()

         self.n_features_in_ = X.shape[1] if len(X.shape) > 1 else 1
diff --git a/sklearnex/basic_statistics/incremental_basic_statistics.py b/sklearnex/basic_statistics/incremental_basic_statistics.py
index d1ddcd55dc..d3671e3602 100644
--- a/sklearnex/basic_statistics/incremental_basic_statistics.py
+++ b/sklearnex/basic_statistics/incremental_basic_statistics.py
@@ -186,9 +186,9 @@ def _get_onedal_result_options(self, options):
         assert isinstance(onedal_options, str)
         return options

-    def _onedal_finalize_fit(self, queue=None):
+    def _onedal_finalize_fit(self):
         assert hasattr(self, "_onedal_estimator")
-        self._onedal_estimator.finalize_fit(queue=queue)
+        self._onedal_estimator.finalize_fit()
         self._need_to_finalize = False

     def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=True):
@@ -258,7 +258,7 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):

         self.n_features_in_ = X.shape[1]

-        self._onedal_finalize_fit(queue=queue)
+        self._onedal_finalize_fit()

         return self
diff --git a/sklearnex/cluster/k_means.py b/sklearnex/cluster/k_means.py
index 4ba75ca5b8..91eeada386 100644
--- a/sklearnex/cluster/k_means.py
+++ b/sklearnex/cluster/k_means.py
@@ -36,7 +36,7 @@
     from daal4py.sklearn._n_jobs_support import control_n_jobs
     from daal4py.sklearn._utils import sklearn_check_version
     from onedal.cluster import KMeans as onedal_KMeans
-    from onedal.utils import _is_csr
+    from onedal.utils.validation import _is_csr

     from .._device_offload import dispatch, wrap_output_data
     from .._utils import PatchingConditionsChain
diff --git a/sklearnex/covariance/incremental_covariance.py b/sklearnex/covariance/incremental_covariance.py
index 89ed92b601..bbf9744933 100644
--- a/sklearnex/covariance/incremental_covariance.py
+++ b/sklearnex/covariance/incremental_covariance.py
@@ -145,9 +145,9 @@ def _onedal_supported(self, method_name, *data):
         )
         return patching_status

-    def _onedal_finalize_fit(self, queue=None):
+    def _onedal_finalize_fit(self):
         assert hasattr(self, "_onedal_estimator")
-        self._onedal_estimator.finalize_fit(queue=queue)
+        self._onedal_estimator.finalize_fit()
         self._need_to_finalize = False

         if not daal_check_version((2024, "P", 400)) and self.assume_centered:
@@ -363,7 +363,7 @@ def _onedal_fit(self, X, queue=None):
             X_batch = X[batch]
             self._onedal_partial_fit(X_batch, queue=queue, check_input=False)

-        self._onedal_finalize_fit(queue=queue)
+        self._onedal_finalize_fit()

         return self
diff --git a/sklearnex/ensemble/_forest.py b/sklearnex/ensemble/_forest.py
index 2a04962645..57bf3e08e0 100644
--- a/sklearnex/ensemble/_forest.py
+++ b/sklearnex/ensemble/_forest.py
@@ -56,7 +56,7 @@
 from onedal.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
 from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
 from onedal.primitives import get_tree_state_cls, get_tree_state_reg
-from onedal.utils import _num_features, _num_samples
+from onedal.utils.validation import _num_features, _num_samples

 from sklearnex import get_hyperparameters
 from sklearnex._utils import register_hyperparameters
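The queue=queue edits in the hunks above are not cosmetic: the onedal-side methods are now wrapped by the queue-handling machinery (supports_queue, per the validation.py hunk earlier), which makes keyword passing the robust choice; a positional argument could collide with another parameter as signatures evolve. A hedged sketch of the calling pattern, where est stands for any _onedal_estimator in these hunks:

est.fit(X, sample_weight, queue=queue)  # queue strictly by keyword
est.partial_fit(X, y, queue=queue)
est.predict(X, queue=queue)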
diff --git a/sklearnex/linear_model/coordinate_descent.py b/sklearnex/linear_model/coordinate_descent.py
index abe594ad29..cd7b2bcdca 100644
--- a/sklearnex/linear_model/coordinate_descent.py
+++ b/sklearnex/linear_model/coordinate_descent.py
@@ -19,12 +19,12 @@

 # Note: `sklearnex.linear_model.ElasticNet` only has functional
 # sycl GPU support. No GPU device will be offloaded.
-ElasticNet.fit = support_input_format(queue_param=False)(ElasticNet.fit)
-ElasticNet.predict = support_input_format(queue_param=False)(ElasticNet.predict)
-ElasticNet.score = support_input_format(queue_param=False)(ElasticNet.score)
+ElasticNet.fit = support_input_format(ElasticNet.fit)
+ElasticNet.predict = support_input_format(ElasticNet.predict)
+ElasticNet.score = support_input_format(ElasticNet.score)

 # Note: `sklearnex.linear_model.Lasso` only has functional
 # sycl GPU support. No GPU device will be offloaded.
-Lasso.fit = support_input_format(queue_param=False)(Lasso.fit)
-Lasso.predict = support_input_format(queue_param=False)(Lasso.predict)
-Lasso.score = support_input_format(queue_param=False)(Lasso.score)
+Lasso.fit = support_input_format(Lasso.fit)
+Lasso.predict = support_input_format(Lasso.predict)
+Lasso.score = support_input_format(Lasso.score)
diff --git a/sklearnex/linear_model/incremental_linear.py b/sklearnex/linear_model/incremental_linear.py
index db2d6549c0..7127c4ee70 100644
--- a/sklearnex/linear_model/incremental_linear.py
+++ b/sklearnex/linear_model/incremental_linear.py
@@ -233,10 +233,10 @@ def _onedal_validate_underdetermined(self, n_samples, n_features):
         if is_underdetermined:
             raise ValueError("Not enough samples for oneDAL")

-    def _onedal_finalize_fit(self, queue=None):
+    def _onedal_finalize_fit(self):
         assert hasattr(self, "_onedal_estimator")
         self._onedal_validate_underdetermined(self.n_samples_seen_, self.n_features_in_)
-        self._onedal_estimator.finalize_fit(queue=queue)
+        self._onedal_estimator.finalize_fit()
         self._need_to_finalize = False

     def _onedal_fit(self, X, y, queue=None):
@@ -294,7 +294,7 @@ def _onedal_fit(self, X, y, queue=None):
                 "Only one sample available. You may want to reshape your data array"
             )

-        self._onedal_finalize_fit(queue=queue)
+        self._onedal_finalize_fit()

         return self

     @property
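_onedal_finalize_fit losing its queue parameter follows the same logic across the incremental estimators in this diff (basic statistics, covariance, linear, ridge, PCA): finalization no longer needs a caller-supplied queue, on the assumption that the queue context is already tracked globally by the time the partial fits have run. The resulting incremental pattern, sketched under that assumption:

# hedged sketch of the incremental flow after this change
for X_batch in batches:
    est._onedal_partial_fit(X_batch, queue=queue)  # queue flows in here

est._onedal_finalize_fit()  # no queue: reuses the managed queue state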
diff --git a/sklearnex/linear_model/incremental_ridge.py b/sklearnex/linear_model/incremental_ridge.py
index e750491ef9..232e6da8ab 100644
--- a/sklearnex/linear_model/incremental_ridge.py
+++ b/sklearnex/linear_model/incremental_ridge.py
@@ -137,7 +137,7 @@ def _onedal_predict(self, X, queue=None):
         assert hasattr(self, "_onedal_estimator")
         if self._need_to_finalize:
             self._onedal_finalize_fit()
-        return self._onedal_estimator.predict(X, queue)
+        return self._onedal_estimator.predict(X, queue=queue)

     def _onedal_score(self, X, y, sample_weight=None, queue=None):
         return r2_score(
@@ -177,7 +177,7 @@ def _onedal_partial_fit(self, X, y, check_input=True, queue=None):
         }
         if not hasattr(self, "_onedal_estimator"):
             self._onedal_estimator = self._onedal_incremental_ridge(**onedal_params)
-        self._onedal_estimator.partial_fit(X, y, queue)
+        self._onedal_estimator.partial_fit(X, y, queue=queue)
         self._need_to_finalize = True

     def _onedal_finalize_fit(self):
diff --git a/sklearnex/linear_model/linear.py b/sklearnex/linear_model/linear.py
index fb7eca8cf1..4b0a2b7454 100644
--- a/sklearnex/linear_model/linear.py
+++ b/sklearnex/linear_model/linear.py
@@ -15,7 +15,6 @@
 # ===============================================================================

 import logging
-from abc import ABC

 import numpy as np
 from sklearn.linear_model import LinearRegression as _sklearn_LinearRegression
@@ -37,7 +36,7 @@
 from onedal.common.hyperparameters import get_hyperparameters
 from onedal.linear_model import LinearRegression as onedal_LinearRegression
-from onedal.utils import _num_features, _num_samples
+from onedal.utils.validation import _num_features, _num_samples

 if sklearn_check_version("1.6"):
     from sklearn.utils.validation import validate_data
diff --git a/sklearnex/linear_model/logistic_regression.py b/sklearnex/linear_model/logistic_regression.py
index 01e944c74f..7af9555cf0 100644
--- a/sklearnex/linear_model/logistic_regression.py
+++ b/sklearnex/linear_model/logistic_regression.py
@@ -34,7 +34,7 @@
     from daal4py.sklearn._utils import sklearn_check_version
     from daal4py.sklearn.linear_model.logistic_path import daal4py_fit, daal4py_predict
     from onedal.linear_model import LogisticRegression as onedal_LogisticRegression
-    from onedal.utils import _num_samples
+    from onedal.utils.validation import _num_samples

     from .._config import get_config
     from .._device_offload import dispatch, wrap_output_data
diff --git a/sklearnex/linear_model/ridge.py b/sklearnex/linear_model/ridge.py
index 85d6714905..b5c135219c 100644
--- a/sklearnex/linear_model/ridge.py
+++ b/sklearnex/linear_model/ridge.py
@@ -35,7 +35,7 @@
     from sklearn.utils import check_scalar

     from onedal.linear_model import Ridge as onedal_Ridge
-    from onedal.utils import _num_features, _num_samples
+    from onedal.utils.validation import _num_features, _num_samples

     from .._device_offload import dispatch, wrap_output_data
     from .._utils import PatchingConditionsChain
@@ -383,8 +383,8 @@ def _save_attributes(self):
     from daal4py.sklearn.linear_model import Ridge
     from onedal._device_offload import support_input_format

-    Ridge.fit = support_input_format(queue_param=False)(Ridge.fit)
-    Ridge.predict = support_input_format(queue_param=False)(Ridge.predict)
-    Ridge.score = support_input_format(queue_param=False)(Ridge.score)
+    Ridge.fit = support_input_format(Ridge.fit)
+    Ridge.predict = support_input_format(Ridge.predict)
+    Ridge.score = support_input_format(Ridge.score)

     logging.warning("Ridge requires oneDAL version >= 2024.6 but it was not found")
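The ridge.py tail above is a version-gated fallback: when the installed oneDAL is older than 2024.6, the daal4py Ridge is exported instead and merely wrapped for device-input support. A sketch of the gate's shape; the exact daal_check_version tuple is an assumption here, and only the warning text is verbatim from the diff:

if daal_check_version((2024, "P", 600)):  # illustrative version tuple
    ...  # the onedal-backed Ridge defined above is exported
else:
    from daal4py.sklearn.linear_model import Ridge
    from onedal._device_offload import support_input_format

    Ridge.fit = support_input_format(Ridge.fit)
    Ridge.predict = support_input_format(Ridge.predict)
    Ridge.score = support_input_format(Ridge.score)

    logging.warning("Ridge requires oneDAL version >= 2024.6 but it was not found")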
diff --git a/sklearnex/linear_model/tests/test_logreg.py b/sklearnex/linear_model/tests/test_logreg.py
index 65c7ea5d0f..6c30760a46 100755
--- a/sklearnex/linear_model/tests/test_logreg.py
+++ b/sklearnex/linear_model/tests/test_logreg.py
@@ -49,7 +49,7 @@ def test_sklearnex_multiclass_classification(dataframe, queue):
     from sklearnex.linear_model import LogisticRegression

     X, y = load_iris(return_X_y=True)
-    X_train, X_test, y_train, y_test = prepare_input(X, y, dataframe, queue)
+    X_train, X_test, y_train, y_test = prepare_input(X, y, dataframe, queue=queue)

     logreg = LogisticRegression(fit_intercept=True, solver="lbfgs", max_iter=200).fit(
         X_train, y_train
@@ -72,7 +72,7 @@ def test_sklearnex_binary_classification(dataframe, queue):
     from sklearnex.linear_model import LogisticRegression

     X, y = load_breast_cancer(return_X_y=True)
-    X_train, X_test, y_train, y_test = prepare_input(X, y, dataframe, queue)
+    X_train, X_test, y_train, y_test = prepare_input(X, y, dataframe, queue=queue)

     logreg = LogisticRegression(fit_intercept=True, solver="newton-cg", max_iter=100).fit(
         X_train, y_train
diff --git a/sklearnex/manifold/t_sne.py b/sklearnex/manifold/t_sne.py
index 0aa8d7df4f..4dea01bc6b 100755
--- a/sklearnex/manifold/t_sne.py
+++ b/sklearnex/manifold/t_sne.py
@@ -17,5 +17,5 @@
 from daal4py.sklearn.manifold import TSNE
 from onedal._device_offload import support_input_format

-TSNE.fit = support_input_format(queue_param=False)(TSNE.fit)
-TSNE.fit_transform = support_input_format(queue_param=False)(TSNE.fit_transform)
+TSNE.fit = support_input_format(TSNE.fit)
+TSNE.fit_transform = support_input_format(TSNE.fit_transform)
diff --git a/sklearnex/metrics/pairwise.py b/sklearnex/metrics/pairwise.py
index 8ad789dce1..ffcc136e1d 100755
--- a/sklearnex/metrics/pairwise.py
+++ b/sklearnex/metrics/pairwise.py
@@ -17,6 +17,4 @@
 from daal4py.sklearn.metrics import pairwise_distances
 from onedal._device_offload import support_input_format

-pairwise_distances = support_input_format(freefunc=True, queue_param=False)(
-    pairwise_distances
-)
+pairwise_distances = support_input_format(pairwise_distances)
diff --git a/sklearnex/metrics/ranking.py b/sklearnex/metrics/ranking.py
index 753be6d0cd..4a4fdb8d65 100755
--- a/sklearnex/metrics/ranking.py
+++ b/sklearnex/metrics/ranking.py
@@ -17,4 +17,4 @@
 from daal4py.sklearn.metrics import roc_auc_score
 from onedal._device_offload import support_input_format

-roc_auc_score = support_input_format(freefunc=True, queue_param=False)(roc_auc_score)
+roc_auc_score = support_input_format(roc_auc_score)
diff --git a/sklearnex/model_selection/split.py b/sklearnex/model_selection/split.py
index 59153114b9..5ed44c7428 100755
--- a/sklearnex/model_selection/split.py
+++ b/sklearnex/model_selection/split.py
@@ -17,6 +17,4 @@
 from daal4py.sklearn.model_selection import train_test_split
 from onedal._device_offload import support_input_format

-train_test_split = support_input_format(freefunc=True, queue_param=False)(
-    train_test_split
-)
+train_test_split = support_input_format(train_test_split)
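All of the wrappers above now share one spelling: support_input_format is a plain decorator, with the old freefunc=True / queue_param=False configuration gone, so bound methods and free functions are wrapped identically. Illustrative usage with device arrays (dpnp is an assumed optional dependency; results coming back in the input's type is the wrapper's purpose, as its name suggests):

import dpnp
from sklearnex.model_selection import train_test_split

X = dpnp.random.rand(100, 5)
y = dpnp.random.rand(100)
X_train, X_test, y_train, y_test = train_test_split(X, y)  # dpnp in, dpnp out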
diff --git a/sklearnex/neighbors/_lof.py b/sklearnex/neighbors/_lof.py
index 1e42f8db0d..ec2f0c7747 100644
--- a/sklearnex/neighbors/_lof.py
+++ b/sklearnex/neighbors/_lof.py
@@ -59,7 +59,7 @@ def _onedal_fit(self, X, y, queue=None):
         if sklearn_check_version("1.2"):
             self._validate_params()

-        self._onedal_knn_fit(X, y, queue)
+        self._onedal_knn_fit(X, y, queue=queue)

         if self.contamination != "auto":
             if not (0.0 < self.contamination <= 0.5):
diff --git a/sklearnex/neighbors/common.py b/sklearnex/neighbors/common.py
index 0ad5a62dd1..3348f06dd1 100644
--- a/sklearnex/neighbors/common.py
+++ b/sklearnex/neighbors/common.py
@@ -25,7 +25,7 @@
 from sklearn.utils.validation import check_is_fitted

 from daal4py.sklearn._utils import sklearn_check_version
-from onedal.utils import _check_array, _num_features, _num_samples
+from onedal.utils.validation import _check_array, _num_features, _num_samples

 from .._utils import PatchingConditionsChain
 from ..utils._array_api import get_namespace
diff --git a/sklearnex/preview/decomposition/incremental_pca.py b/sklearnex/preview/decomposition/incremental_pca.py
index fdf13e0817..949ae5ec40 100644
--- a/sklearnex/preview/decomposition/incremental_pca.py
+++ b/sklearnex/preview/decomposition/incremental_pca.py
@@ -59,7 +59,7 @@ def _onedal_transform(self, X, queue=None):
         if self._need_to_finalize:
             self._onedal_finalize_fit()
         X = check_array(X, dtype=[np.float64, np.float32])
-        return self._onedal_estimator.predict(X, queue)
+        return self._onedal_estimator.predict(X, queue=queue)

     def _onedal_fit_transform(self, X, queue=None):
         self._onedal_fit(X, queue)
@@ -114,9 +114,9 @@ def _onedal_partial_fit(self, X, check_input=True, queue=None):
         self._onedal_estimator.partial_fit(X, queue=queue)
         self._need_to_finalize = True

-    def _onedal_finalize_fit(self, queue=None):
+    def _onedal_finalize_fit(self):
         assert hasattr(self, "_onedal_estimator")
-        self._onedal_estimator.finalize_fit(queue=queue)
+        self._onedal_estimator.finalize_fit()
         self._need_to_finalize = False

     def _onedal_fit(self, X, queue=None):
@@ -147,7 +147,7 @@ def _onedal_fit(self, X, queue=None):
             X_batch = X[batch]
             self._onedal_partial_fit(X_batch, queue=queue)

-        self._onedal_finalize_fit(queue=queue)
+        self._onedal_finalize_fit()

         return self
diff --git a/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py b/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py
index d2b0cc5704..29c5ad8154 100644
--- a/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py
+++ b/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py
@@ -62,7 +62,8 @@ def test_basic_stats_spmd_gold(dataframe, queue):
     )

     # Ensure results of batch algo match spmd
-    spmd_result = BasicStatistics_SPMD().fit(local_dpt_data)
+    spmd = BasicStatistics_SPMD()
+    spmd_result = spmd.fit(local_dpt_data)
     batch_result = BasicStatistics_Batch().fit(data)

     for option in options_and_tests:
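In the SPMD gold test above, the estimator is now bound to a name before fitting rather than fitted from a temporary. A sketch of the comparison pattern the test implements (the per-option assertion body is elided in the hunk, so its shape here is an assumption):

spmd = BasicStatistics_SPMD()
spmd_result = spmd.fit(local_dpt_data)            # per-rank USM chunk
batch_result = BasicStatistics_Batch().fit(data)  # full data, single process

for option in options_and_tests:
    ...  # compare each requested statistic on spmd_result against batch_result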
diff --git a/sklearnex/spmd/neighbors/__init__.py b/sklearnex/spmd/neighbors/__init__.py
index 8036511d9f..7b1f9f646c 100644
--- a/sklearnex/spmd/neighbors/__init__.py
+++ b/sklearnex/spmd/neighbors/__init__.py
@@ -14,6 +14,6 @@
 # limitations under the License.
 # ==============================================================================

-from .neighbors import KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors
+from onedal.spmd.neighbors import KNeighborsClassifier, KNeighborsRegressor

-__all__ = ["KNeighborsClassifier", "KNeighborsRegressor", "NearestNeighbors"]
+__all__ = ["KNeighborsClassifier", "KNeighborsRegressor"]
diff --git a/sklearnex/spmd/neighbors/neighbors.py b/sklearnex/spmd/neighbors/neighbors.py
deleted file mode 100644
index 5b569c8e1f..0000000000
--- a/sklearnex/spmd/neighbors/neighbors.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# ==============================================================================
-# Copyright 2023 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from onedal.spmd.neighbors import (
-    KNeighborsClassifier,
-    KNeighborsRegressor,
-    NearestNeighbors,
-)
-
-# TODO:
-# Currently it uses `onedal` module interface.
-# Add sklearnex dispatching.
diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py
index 4b481314ae..bd31336edb 100644
--- a/sklearnex/svm/_common.py
+++ b/sklearnex/svm/_common.py
@@ -26,7 +26,7 @@
 from sklearn.preprocessing import LabelEncoder

 from daal4py.sklearn._utils import sklearn_check_version
-from onedal.utils import _check_array, _check_X_y, _column_or_1d
+from onedal.utils.validation import _check_array, _check_X_y, _column_or_1d

 from .._config import config_context, get_config
 from .._utils import PatchingConditionsChain
diff --git a/sklearnex/tests/test_memory_usage.py b/sklearnex/tests/test_memory_usage.py
index 2d52a545cf..5520244cb5 100644
--- a/sklearnex/tests/test_memory_usage.py
+++ b/sklearnex/tests/test_memory_usage.py
@@ -18,7 +18,6 @@
 import logging
 import os
 import tracemalloc
-import types
 import warnings
 from inspect import isclass
@@ -29,7 +28,7 @@
 from sklearn.datasets import make_classification
 from sklearn.model_selection import KFold

-from onedal import _is_dpc_backend
+from onedal import _default_backend as backend
 from onedal.tests.utils._dataframes_support import (
     _convert_to_dataframe,
     get_dataframes_and_queues,
@@ -51,9 +50,6 @@
 if dpnp_available:
     import dpnp

-if _is_dpc_backend:
-    from onedal import _backend
-

 CPU_SKIP_LIST = (
     "TSNE",  # too slow for using in testing on common data size
@@ -149,8 +145,8 @@ def gen_clsf_data(n_samples, n_features, dtype=None):

 def get_traced_memory(queue=None):
-    if _is_dpc_backend and queue and queue.sycl_device.is_gpu:
-        return _backend.get_used_memory(queue)
+    if backend.is_dpc and queue and queue.sycl_device.is_gpu:
+        return backend.get_used_memory(queue)
     else:
         return tracemalloc.get_traced_memory()[0]

@@ -320,7 +316,7 @@ def test_gpu_memory_leaks(estimator, queue, order, data_shape):

 @pytest.mark.skipif(
-    not _is_dpc_backend,
+    not backend.is_dpc,
     reason="__sycl_usm_array_interface__ support requires DPC backend.",
 )
 @pytest.mark.parametrize(
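The test_memory_usage.py hunks show the payoff of the Backend wrapper: capability checks and backend calls go through one object, with no conditional import of _backend. Attribute access is forwarded to the underlying pybind11 module, so the probe below behaves like the old two-step import (get_used_memory is the DPC backend symbol used in the hunk above):

from onedal import _default_backend as backend

if backend.is_dpc:
    # forwarded to the dpc pybind11 module via Backend.__getattr__
    used = backend.get_used_memory(queue)
else:
    import tracemalloc
    used = tracemalloc.get_traced_memory()[0]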
diff --git a/sklearnex/tests/test_monkeypatch.py b/sklearnex/tests/test_monkeypatch.py
index 995fab29e2..3a4dfdda8c 100755
--- a/sklearnex/tests/test_monkeypatch.py
+++ b/sklearnex/tests/test_monkeypatch.py
@@ -42,8 +42,9 @@ def test_monkey_patching():
             n = _classes[i][1]
             class_module = getattr(p, n).__module__
-            assert class_module.startswith("daal4py") or class_module.startswith(
-                "sklearnex"
+            assert any(
+                class_module.startswith(prefix)
+                for prefix in ["daal4py", "sklearnex", "onedal"]
             ), "Patching has completed with error."

     for i, _ in enumerate(_tokens):
@@ -87,8 +88,9 @@ def test_monkey_patching():
             sklearnex.patch_sklearn(t)

             class_module = getattr(p, n).__module__
-            assert class_module.startswith("daal4py") or class_module.startswith(
-                "sklearnex"
+            assert any(
+                class_module.startswith(prefix)
+                for prefix in ["daal4py", "sklearnex", "onedal"]
             ), "Patching has completed with error."
     finally:
         sklearnex.unpatch_sklearn()
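The widened prefix list reflects that patched estimators may now come straight from onedal rather than only daal4py or sklearnex. An equivalent standalone check (the estimator choice here is illustrative):

import sklearnex

sklearnex.patch_sklearn()
from sklearn.cluster import KMeans

assert any(
    KMeans.__module__.startswith(prefix)
    for prefix in ["daal4py", "sklearnex", "onedal"]
)
sklearnex.unpatch_sklearn()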