# -*- coding: utf-8 -*-
#
# CLASSIX: Fast and explainable clustering based on sorting
#
# MIT License
#
# Copyright (c) 2024 Stefan Güttel, Xinye Chen
import warnings
import os
import copy
import numbers
import collections
import numpy as np
import pandas as pd
from numpy.linalg import norm
from scipy.spatial import distance
from time import time
def cython_is_available(verbose=0):
"""Check if CLASSIX is using Cython."""
__cython_type__ = "memoryview"
from . import __enable_cython__
if __enable_cython__:
try:
# %load_ext Cython
# !python3 setup.py build_ext --inplace
import numpy
try: # check if Cython packages are loaded properly
from .aggregate_cm import general_aggregate, pca_aggregate
from .merge_cm import density_merge, distance_merge
# cython with memoryviews
# Typed memoryviews allow efficient access to memory buffers, such as those underlying NumPy arrays, without incurring any Python overhead.
except ModuleNotFoundError:
from .aggregate_c import general_aggregate, pca_aggregate
__cython_type__ = "trivial"
if verbose:
if __cython_type__ == "memoryview":
print("This CLASSIX is using Cython typed memoryviews.")
else:
print("This CLASSIX is not using Cython typed memoryviews.")
return True
except (ModuleNotFoundError, ValueError):
if verbose:
print("CLASSIX is currently not using Cython.")
return False
else:
if verbose:
print("Cython is currently disabled. Please set ``__enable_cython__`` to True to enable Cython.")
return False
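# A minimal usage sketch (hypothetical interactive session; assumes the
# package is importable as ``classix``):
#
#   >>> import classix
#   >>> classix.cython_is_available(verbose=1)
#   This CLASSIX is using Cython typed memoryviews.
#   True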
def loadData(name='vdu_signals'):
"""Load built-in sample data.
Parameters
----------
name: str, {'vdu_signals', 'Iris', 'Dermatology', 'Ecoli', 'Glass',
'Banknote', 'Seeds', 'Phoneme', 'Wine', 'Covid3MC', 'CovidENV'},
default='vdu_signals'
Identifier of the built-in dataset.
Returns
-------
X, y: numpy.ndarray
Data and ground-truth labels (if available).
"""
current_dir, current_filename = os.path.split(__file__)
if not os.path.isdir(os.path.join(current_dir, 'data')):
os.mkdir(os.path.join(current_dir, 'data/'))
if name == 'vdu_signals':
DATA_PATH = os.path.join(current_dir, 'data/vdu_signals.npy')
if not os.path.isfile(DATA_PATH):
get_data(current_dir)
return np.load(DATA_PATH)
if name == 'Iris':
DATA_PATH_X = os.path.join(current_dir, 'data/X_Irirs.npy')
DATA_PATH_Y = os.path.join(current_dir, 'data/y_Irirs.npy')
if not os.path.isfile(DATA_PATH_X) or not os.path.isfile(DATA_PATH_Y):
get_data(current_dir, 'Iris')
return np.load(DATA_PATH_X), np.load(DATA_PATH_Y)
if name == 'Dermatology':
DATA_PATH_X = os.path.join(current_dir, 'data/X_Dermatology.npy')
DATA_PATH_Y = os.path.join(current_dir, 'data/y_Dermatology.npy')
if not os.path.isfile(DATA_PATH_X) or not os.path.isfile(DATA_PATH_Y):
get_data(current_dir, 'Dermatology')
return np.load(DATA_PATH_X), np.load(DATA_PATH_Y)
if name == 'Ecoli':
DATA_PATH_X = os.path.join(current_dir, 'data/X_Ecoli.npy')
DATA_PATH_Y = os.path.join(current_dir, 'data/y_Ecoli.npy')
if not os.path.isfile(DATA_PATH_X) or not os.path.isfile(DATA_PATH_Y):
get_data(current_dir, 'Ecoli')
return np.load(DATA_PATH_X), np.load(DATA_PATH_Y)
if name == 'Glass':
DATA_PATH_X = os.path.join(current_dir, 'data/X_Glass.npy')
DATA_PATH_Y = os.path.join(current_dir, 'data/y_Glass.npy')
if not os.path.isfile(DATA_PATH_X) or not os.path.isfile(DATA_PATH_Y):
get_data(current_dir, 'Glass')
return np.load(DATA_PATH_X), np.load(DATA_PATH_Y)
if name == 'Banknote':
DATA_PATH_X = os.path.join(current_dir, 'data/X_Banknote.npy')
DATA_PATH_Y = os.path.join(current_dir, 'data/y_Banknote.npy')
if not os.path.isfile(DATA_PATH_X) or not os.path.isfile(DATA_PATH_Y):
get_data(current_dir, 'Banknote')
return np.load(DATA_PATH_X), np.load(DATA_PATH_Y)
if name == 'Seeds':
DATA_PATH_X = os.path.join(current_dir, 'data/X_Seeds.npy')
DATA_PATH_Y = os.path.join(current_dir, 'data/y_Seeds.npy')
if not os.path.isfile(DATA_PATH_X) or not os.path.isfile(DATA_PATH_Y):
get_data(current_dir, 'Seeds')
return np.load(DATA_PATH_X), np.load(DATA_PATH_Y)
if name == 'Phoneme':
DATA_PATH_X = os.path.join(current_dir, 'data/X_Phoneme.npy')
DATA_PATH_Y = os.path.join(current_dir, 'data/y_Phoneme.npy')
if not os.path.isfile(DATA_PATH_X) or not os.path.isfile(DATA_PATH_Y):
get_data(current_dir, 'Phoneme')
return np.load(DATA_PATH_X), np.load(DATA_PATH_Y)
if name == 'Wine':
DATA_PATH_X = os.path.join(current_dir, 'data/X_Wine.npy')
DATA_PATH_Y = os.path.join(current_dir, 'data/y_Wine.npy')
if not os.path.isfile(DATA_PATH_X) or not os.path.isfile(DATA_PATH_Y):
get_data(current_dir, 'Wine')
return np.load(DATA_PATH_X), np.load(DATA_PATH_Y)
if name == 'CovidENV':
DATA_PATH_X = os.path.join(current_dir, 'data/X_CovidENV.pkl')
DATA_PATH_Y = os.path.join(current_dir, 'data/y_CovidENV.npy')
if not os.path.isfile(DATA_PATH_X) or not os.path.isfile(DATA_PATH_Y):
get_data(current_dir, 'CovidENV')
return pd.read_pickle(DATA_PATH_X), np.load(DATA_PATH_Y)
if name == 'Covid3MC':
DATA_PATH_X = os.path.join(current_dir, 'data/X_Covid3MC.pkl')
DATA_PATH_Y = os.path.join(current_dir, 'data/y_Covid3MC.npy')
if not os.path.isfile(DATA_PATH_X) or not os.path.isfile(DATA_PATH_Y):
get_data(current_dir, 'Covid3MC')
return pd.read_pickle(DATA_PATH_X), np.load(DATA_PATH_Y)
if name not in ['vdu_signals', 'Iris', 'Dermatology', 'Ecoli', 'Glass',
'Banknote', 'Seeds', 'Phoneme', 'Wine', 'CovidENV', 'Covid3MC']:
warnings.warn("Invalid dataset identifier.")
def get_data(current_dir='', name='vdu_signals'):
"""Download the built-in sample data from the web."""
import requests
if name == 'vdu_signals':
url_parent = "https://github.com/nla-group/classix/raw/master/classix/source/vdu_signals.npy"
vdu_signals = requests.get(url_parent).content
with open(os.path.join(current_dir, 'data/vdu_signals.npy'), 'wb') as handler:
handler.write(vdu_signals)
elif name == 'Iris':
url_parent_x = "https://github.com/nla-group/classix/raw/master/classix/source/X_Irirs.npy"
url_parent_y = "https://github.com/nla-group/classix/raw/master/classix/source/y_Irirs.npy"
x = requests.get(url_parent_x).content
y = requests.get(url_parent_y).content
with open(os.path.join(current_dir, 'data/X_Irirs.npy'), 'wb') as handler:
handler.write(x)
with open(os.path.join(current_dir, 'data/y_Irirs.npy'), 'wb') as handler:
handler.write(y)
elif name == 'Dermatology':
url_parent_x = "https://github.com/nla-group/classix/raw/master/classix/source/X_Dermatology.npy"
url_parent_y = "https://github.com/nla-group/classix/raw/master/classix/source/y_Dermatology.npy"
x = requests.get(url_parent_x).content
y = requests.get(url_parent_y).content
with open(os.path.join(current_dir, 'data/X_Dermatology.npy'), 'wb') as handler:
handler.write(x)
with open(os.path.join(current_dir, 'data/y_Dermatology.npy'), 'wb') as handler:
handler.write(y)
elif name == 'Ecoli':
url_parent_x = "https://github.com/nla-group/classix/raw/master/classix/source/X_Ecoli.npy"
url_parent_y = "https://github.com/nla-group/classix/raw/master/classix/source/y_Ecoli.npy"
x = requests.get(url_parent_x).content
y = requests.get(url_parent_y).content
with open(os.path.join(current_dir, 'data/X_Ecoli.npy'), 'wb') as handler:
handler.write(x)
with open(os.path.join(current_dir, 'data/y_Ecoli.npy'), 'wb') as handler:
handler.write(y)
elif name == 'Glass':
url_parent_x = "https://github.com/nla-group/classix/raw/master/classix/source/X_Glass.npy"
url_parent_y = "https://github.com/nla-group/classix/raw/master/classix/source/y_Glass.npy"
x = requests.get(url_parent_x).content
y = requests.get(url_parent_y).content
with open(os.path.join(current_dir, 'data/X_Glass.npy'), 'wb') as handler:
handler.write(x)
with open(os.path.join(current_dir, 'data/y_Glass.npy'), 'wb') as handler:
handler.write(y)
elif name == 'Banknote':
url_parent_x = "https://github.com/nla-group/classix/raw/master/classix/source/X_Banknote.npy"
url_parent_y = "https://github.com/nla-group/classix/raw/master/classix/source/y_Banknote.npy"
x = requests.get(url_parent_x).content
y = requests.get(url_parent_y).content
with open(os.path.join(current_dir, 'data/X_Banknote.npy'), 'wb') as handler:
handler.write(x)
with open(os.path.join(current_dir, 'data/y_Banknote.npy'), 'wb') as handler:
handler.write(y)
elif name == 'Seeds':
url_parent_x = "https://github.com/nla-group/classix/raw/master/classix/source/X_Seeds.npy"
url_parent_y = "https://github.com/nla-group/classix/raw/master/classix/source/y_Seeds.npy"
x = requests.get(url_parent_x).content
y = requests.get(url_parent_y).content
with open(os.path.join(current_dir, 'data/X_Seeds.npy'), 'wb') as handler:
handler.write(x)
with open(os.path.join(current_dir, 'data/y_Seeds.npy'), 'wb') as handler:
handler.write(y)
elif name == 'Phoneme':
url_parent_x = "https://github.com/nla-group/classix/raw/master/classix/source/X_Phoneme.npy"
url_parent_y = "https://github.com/nla-group/classix/raw/master/classix/source/y_Phoneme.npy"
x = requests.get(url_parent_x).content
y = requests.get(url_parent_y).content
with open(os.path.join(current_dir, 'data/X_Phoneme.npy'), 'wb') as handler:
handler.write(x)
with open(os.path.join(current_dir, 'data/y_Phoneme.npy'), 'wb') as handler:
handler.write(y)
elif name == 'Wine':
url_parent_x = "https://github.com/nla-group/classix/raw/master/classix/source/X_Wine.npy"
url_parent_y = "https://github.com/nla-group/classix/raw/master/classix/source/y_Wine.npy"
x = requests.get(url_parent_x).content
y = requests.get(url_parent_y).content
with open(os.path.join(current_dir, 'data/X_Wine.npy'), 'wb') as handler:
handler.write(x)
with open(os.path.join(current_dir, 'data/y_Wine.npy'), 'wb') as handler:
handler.write(y)
elif name == 'CovidENV':
url_parent_x = "https://github.com/the-null/data/raw/main/X_CovidENV.pkl"
url_parent_y = "https://github.com/nla-group/classix/raw/master/classix/source/y_CovidENV.npy"
x = requests.get(url_parent_x).content
y = requests.get(url_parent_y).content
with open(os.path.join(current_dir, 'data/X_CovidENV.pkl'), 'wb') as handler:
handler.write(x)
with open(os.path.join(current_dir, 'data/y_CovidENV.npy'), 'wb') as handler:
handler.write(y)
elif name == 'Covid3MC':
url_parent_x = "https://github.com/the-null/data/raw/main/X_Covid3MC.pkl"
url_parent_y = "https://github.com/nla-group/classix/raw/master/classix/source/y_Covid3MC.npy"
x = requests.get(url_parent_x).content
y = requests.get(url_parent_y).content
with open(os.path.join(current_dir, 'data/X_Covid3MC.pkl'), 'wb') as handler:
handler.write(x)
with open(os.path.join(current_dir, 'data/y_Covid3MC.npy'), 'wb') as handler:
handler.write(y)
class NotFittedError(ValueError, AttributeError):
"""Exception class to raise if estimator is used before fitting.
"""
# ******************************************** the main wrapper ********************************************
class CLASSIX:
"""CLASSIX: Fast and explainable clustering based on sorting.
The main parameters are ``radius`` and ``minPts``.
Parameters
----------
sorting : str, {'pca', 'norm-mean', 'norm-orthant', None}, default='pca'
Sorting method used for the aggregation phase.
- 'pca': sort data points by their first principal component
- 'norm-mean': shift data to have zero mean and then sort by 2-norm values
- 'norm-orthant': shift data to positive orthant and then sort by 2-norm values
- None: aggregate the raw data without any sorting
radius : float, default=0.5
Tolerance that controls the aggregation. If the distance between a group center
and an object is less than or equal to the tolerance, the object is allocated
to that center's group. For details, see [1].
group_merging : str, {'density', 'distance', None}, default='distance'
The method used to merge groups.
- 'distance': two groups are merged if the distance between their group centers is at
most mergeScale*radius (see the radius parameter above).
- 'density': two groups are merged if the density of data points in their intersection
is at least as high as the smaller of the two groups' densities. This option uses a
disjoint-set structure for the merging.
If group_merging is set to None, the method returns the labels formed by aggregation
as the cluster labels.
minPts : int, default=1
Clusters with fewer than minPts points are classified as abnormal clusters.
The data points in an abnormal cluster will be redistributed to the nearest normal cluster.
When set to 1, no redistribution is performed.
norm : boolean, default=True
Whether to normalize the data in the way associated with the chosen sorting.
mergeScale : float, default=1.5
Used with distance-based merging: two groups are merged when the distance
between their group centers is smaller than mergeScale*radius.
post_alloc : boolean, default=True
Whether to allocate outliers to the closest groups, hence the corresponding clusters.
If False, all outliers will be labeled as -1.
mergeTinyGroups : boolean, default=True
If False, the group merging ignores all groups with fewer than minPts points.
verbose : boolean or int, default=1
Whether to print log messages.
short_log_form : boolean, default=True
Whether to use the short log form, which truncates the printed list of clusters.
Attributes
----------
groups_ : numpy.ndarray
Group labels from the aggregation phase.
splist_ : numpy.ndarray
List of group centers formed in the aggregation.
labels_ : numpy.ndarray
Cluster labels of the data objects.
group_outliers_ : numpy.ndarray
Indices of outliers at the aggregation-group level, i.e., indices of the
abnormal groups within clusters that have fewer than minPts points.
clusterSizes_ : array
The cardinality of each cluster.
groupCenters_ : array
Indices of the group centers (starting points) with respect to the original data order.
nrDistComp_ : float
The number of distance computations.
dataScale_ : float
The data scaling factor.
Methods
----------
fit(data):
Cluster the data; the fitted model parameters are stored. The labels can be extracted by calling ``self.labels_``.
fit_transform(data):
Cluster the data and return the labels. The labels can also be extracted by calling ``self.labels_``.
predict(data):
After clustering the in-sample data, predict the labels of out-of-sample data.
Each data point is allocated to the cluster whose group center (starting point
from the aggregation stage) is nearest.
gcIndices(ids):
Return the group center (i.e., starting point) location in the data.
explain(index1, index2, ...):
Explain the computed clustering.
The indices index1 and index2 are optional parameters (int) corresponding to the
indices of the data points.
load_group_centers():
Load the group centers.
load_cluster_centers():
Load the cluster centers.
getPath(index1, index2, include_dist=False):
Return the indices of connected data points between index1 data and index2 data.
preprocessing(data):
Normalize the data according to the fitted model.
References
----------
[1] X. Chen and S. Güttel. Fast and explainable clustering based on sorting,
https://arxiv.org/abs/2202.01456, 2022.
"""
def __init__(self, sorting="pca", radius=0.5, minPts=1, group_merging="distance", norm=True, mergeScale=1.5,
post_alloc=True, mergeTinyGroups=True, verbose=1, short_log_form=True):
self.__verbose = verbose
self.minPts = int(minPts)
self.sorting = sorting
self.radius = radius
self.group_merging = group_merging
self.mergeScale_ = mergeScale # scaling factor for the distance-based merging criterion
self.__post_alloc = post_alloc
self.__mergeTinyGroups = mergeTinyGroups
self.__truncate = short_log_form
self.labels_ = None
self._gcIndices = np.frompyfunc(self.gc2ind, 1, 1)
if self.__verbose:
print(self)
from . import __enable_cython__
self.__enable_cython__ = __enable_cython__
self.__enable_aggregate_cython__ = False
import platform
if self.__enable_cython__:
try:
try:
from .aggregate_cm import general_aggregate, pca_aggregate, lm_aggregate
except ModuleNotFoundError:
from .aggregate_c import general_aggregate, pca_aggregate, lm_aggregate
self.__enable_aggregate_cython__ = True
if platform.system() == 'Windows':
from .merge_cm_win import density_merge, distance_merge, distance_merge_mtg
else:
from .merge_cm import density_merge, distance_merge, distance_merge_mtg
except (ModuleNotFoundError, ValueError):
if not self.__enable_aggregate_cython__:
from .aggregate import general_aggregate, pca_aggregate, lm_aggregate
from .merge import density_merge, distance_merge, distance_merge_mtg
warnings.warn("This CLASSIX installation is not using Cython.")
else:
from .aggregate import general_aggregate, pca_aggregate, lm_aggregate
from .merge import density_merge, distance_merge, distance_merge_mtg
warnings.warn("This run of CLASSIX is not using Cython.")
if platform.system() == 'Windows':
if sorting == 'pca':
self._aggregate = pca_aggregate
else:
self._aggregate = general_aggregate
else:
self._aggregate = lm_aggregate
self._density_merge = density_merge
if self.__mergeTinyGroups:
self._distance_merge = distance_merge
else:
self._distance_merge = distance_merge_mtg
def fit(self, data):
"""
Cluster the data. The fitted estimator is returned; the labels are stored in ``self.labels_``.
Parameters
----------
data : numpy.ndarray
Array-like input of shape (n_samples, n_features); a 1-D array is reshaped to a single feature.
"""
if isinstance(data, pd.core.frame.DataFrame):
self._index_data = data.index
if not isinstance(data, np.ndarray):
data = np.array(data)
if len(data.shape) == 1:
data = data.reshape(-1,1)
self.t1_prepare = time()
if data.dtype != 'float64':
data = data.astype('float64')
if self.sorting == "norm-mean":
self.mu_ = data.mean(axis=0)
self.data = data - self.mu_
self.dataScale_ = self.data.std()
if self.dataScale_ == 0: # prevent zero-division
self.dataScale_ = 1
self.data = self.data / self.dataScale_
elif self.sorting == "pca":
self.mu_ = data.mean(axis=0)
self.data = data - self.mu_ # mean center
rds = norm(self.data, axis=1) # distance of each data point from 0
self.dataScale_ = np.median(rds) # 50% of data points are within that radius
if self.dataScale_ == 0: # prevent zero-division
self.dataScale_ = 1
self.data = self.data / self.dataScale_ # now 50% of data are in unit ball
elif self.sorting == "norm-orthant":
self.mu_ = data.min(axis=0)
self.data = data - self.mu_
self.dataScale_ = self.data.std()
if self.dataScale_ == 0: # prevent zero-division
self.dataScale_ = 1
self.data = self.data / self.dataScale_
else:
self.mu_, self.dataScale_ = 0, 1 # no preprocessing
self.data = (data - self.mu_) / self.dataScale_
self.t1_prepare = time() - self.t1_prepare
self.t2_aggregate = time()
# aggregation
self.groups_, self.splist_, self.nrDistComp_, self.ind, sort_vals, self.data, self.__half_nrm2 = self._aggregate(
data=self.data,
sorting=self.sorting,
tol=self.radius
)
if self.__half_nrm2 is None:
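# row-wise 0.5*||x||^2 of the scaled data, precomputed for the distance evaluations in merging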
self.__half_nrm2 = np.einsum('ij,ij->i', self.data, self.data) * 0.5
self.splist_ = np.array(self.splist_)
self.t2_aggregate = time() - self.t2_aggregate
self.t3_merge = time()
if self.group_merging is None or str(self.group_merging).lower() == 'none':
self.inverse_ind = np.argsort(self.ind)
self.labels_ = copy.deepcopy(self.groups_[self.inverse_ind])
else:
self.labels_ = self.merging(
data=self.data,
agg_labels=self.groups_,
splist=self.splist_,
ind=self.ind, sort_vals=sort_vals,
radius=self.radius,
method=self.group_merging,
minPts=self.minPts
)
self.t3_merge = time() - self.t3_merge
self.__fit__ = True
return self
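# Sketch: fitting a pandas DataFrame preserves its index (stored in
# self._index_data above), so .explain() can later refer to rows by name.
# The values below are made up for illustration:
#
#   >>> import pandas as pd
#   >>> df = pd.DataFrame([[0.0, 0.1], [0.2, 0.1], [5.0, 5.2]],
#   ...                   index=['a', 'b', 'c'])
#   >>> clx = CLASSIX(radius=0.5, verbose=0).fit(df)
#   >>> clx.labels_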
def predict(self, data):
"""
Allocate the data to their nearest clusters.
- data : numpy.ndarray
The ndarray-like input of shape (n_samples, n_features).
Returns
-------
labels : numpy.ndarray
The predicted clustering labels.
"""
if hasattr(self, '__fit__'):
if not hasattr(self, 'label_change'):
if not hasattr(self, 'inverse_ind'):
self.inverse_ind = np.argsort(self.ind)
groups = np.asarray(self.groups_)
self.label_change = dict(zip(groups[self.inverse_ind], self.labels_))
else:
raise NotFittedError("Please use .fit() method first.")
data = self.preprocessing(np.asarray(data))
indices = self.splist_[:,0].astype(int)
splist = self.data[indices]
splabels = np.argmin(distance.cdist(splist, data), axis=0)
labels = [self.label_change[i] for i in splabels]
return labels
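# Out-of-sample sketch: after fitting, new points are assigned to the cluster
# of their nearest group center (X as in the quick-start example above):
#
#   >>> clx = CLASSIX(radius=0.5, verbose=0).fit(X)
#   >>> clx.predict([[0.1, 0.0], [5.1, 5.0]])   # returns a list of labels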
def merging(self, data, agg_labels, splist, ind, sort_vals, radius=0.5, method="distance", minPts=1):
"""
Merge groups after aggregation.
Parameters
----------
data : numpy.ndarray
Array-like input of shape (n_samples, n_features).
agg_labels : list
Group labels from the aggregation.
splist : numpy.ndarray
List of group centers formed in the aggregation.
ind : numpy.ndarray
Indices that sort the data, as produced by the aggregation.
sort_vals : numpy.ndarray
The values by which the data were sorted.
radius : float, default=0.5
Tolerance that controls the aggregation and hence the whole clustering process. During aggregation,
if the distance between a starting point and an object is less than or equal to the tolerance,
the object is allocated to the group of that starting point.
method : str, default='distance'
The group merging method, either 'distance' or 'density'.
minPts : int, default=1
Threshold, in the range [0, infinity), that determines the noise degree.
When set to 0, the algorithm does not check for noise.
Returns
-------
labels : numpy.ndarray
The cluster labels of the data.
"""
if method == 'density':
agg_labels = np.asarray(agg_labels)
labels = copy.deepcopy(agg_labels)
self.merge_groups, self.connected_pairs_ = self._density_merge(data, splist,
radius, sort_vals=sort_vals,
half_nrm2=self.__half_nrm2)
maxid = max(labels) + 1
# after this step, the connected pairs (groups) will be transformed into merged clusters,
for sublabels in self.merge_groups: # some aggregated groups may be independent and not included in self.merge_groups
# not labels[sublabels] = maxid !!!
for j in sublabels:
labels[labels == j] = maxid
maxid = maxid + 1
# The resulting clusters may still contain very small, isolated clusters, which are potential "noise".
# The next step therefore extracts clusters with very few objects as potential noise.
# We compare cluster sizes against percentiles of the cluster size distribution. For example, for a
# dataset of 100 points split into 4 clusters of sizes 5, 20, 25, and 50, the 10th percentile
# (percent=10, noise_mergeScale=0.1) of (5, 20, 25, 50) is 14, and the threshold is
# 100 * noise_mergeScale = 10. The first cluster satisfies both 5 < 14 and 5 < 10, so its objects
# are classified as outliers and then reallocated to the closest remaining cluster.
# This method is quite effective against noise arising from a small tolerance (radius).
self.old_cluster_count = collections.Counter(labels)
self.t4_minPts = time()
if minPts >= 1:
potential_noise_labels = self.outlier_filter(min_samples=minPts) # calculate the min_samples directly
SIZE_NOISE_LABELS = len(potential_noise_labels)
if SIZE_NOISE_LABELS == len(np.unique(labels)):
warnings.warn(
"The noise-related parameter settings are invalid; falling back to the method without noise detection.",
DeprecationWarning)
else:
for i in np.unique(potential_noise_labels):
labels[labels == i] = maxid # mark as noise; maxid is not among the existing label numbers.
if SIZE_NOISE_LABELS > 0:
self.clean_index_ = labels != maxid
agln = agg_labels[self.clean_index_]
label_change = dict(zip(agln, labels[self.clean_index_])) # maps group labels to cluster labels.
# allocate the outliers to the corresponding closest cluster.
self.group_outliers_ = np.unique(agg_labels[~self.clean_index_]) # abnormal groups
unique_agln = np.unique(agln)
splist_clean = splist[unique_agln]
if self.__post_alloc:
for nsp in self.group_outliers_:
alloc_class = np.argmin(
np.linalg.norm(data[splist_clean[:, 0].astype(int)] - data[int(splist[nsp, 0])], axis=1, ord=2)
)
labels[agg_labels == nsp] = label_change[unique_agln[alloc_class]]
else:
labels[np.isin(agg_labels, self.group_outliers_)] = -1
# Remove the noise cluster to avoid it connecting two separate clusters into one.
# The label maxid marks the noise points.
self.t4_minPts = time() - self.t4_minPts
else:
self.__half_nrm2 = self.__half_nrm2[self.splist_[:, 0]]
labels, self.old_cluster_count, SIZE_NOISE_LABELS = self._distance_merge(data=data,
labels=agg_labels,
splist=splist,
radius=radius,
minPts=minPts,
scale=self.mergeScale_,
sort_vals=sort_vals,
half_nrm2=self.__half_nrm2
)
self.inverse_ind = np.argsort(ind)
labels = labels[self.inverse_ind]
if self.__verbose == 1:
nr_old_clust_count = len(self.old_cluster_count)
print("""CLASSIX aggregated the {datalen} data points into {num_group} groups. """.format(datalen=len(data), num_group=splist.shape[0]))
print("""In total, {dist:.0f} distances were computed ({avg:.1f} per data point). """.format(dist=self.nrDistComp_, avg=self.nrDistComp_/len(data)))
print("""The {num_group} groups were merged into {c_size} clusters.""".format(
num_group=splist.shape[0], c_size=nr_old_clust_count))
if nr_old_clust_count > 20:
print("The largest 20 clusters have the following sizes:")
else:
print("The clusters have the following sizes:")
self.pprint_format(self.old_cluster_count, truncate=self.__truncate)
if self.minPts > 1 and SIZE_NOISE_LABELS > 0:
print("As minPts is {minPts}, the number of clusters has been reduced to {r}.".format(
minPts=self.minPts, r=len(np.unique(labels))
))
print("Use the verbose=0 parameter to suppress this info.\nUse the .explain() method to explain the clustering.")
return labels
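# The merging strategy is selected at construction time rather than by calling
# merging() directly (sketch; X as in the quick-start example above):
#
#   >>> CLASSIX(radius=0.5, group_merging='distance', verbose=0).fit(X)  # default
#   >>> CLASSIX(radius=0.5, group_merging='density', verbose=0).fit(X)
#   >>> CLASSIX(radius=0.5, group_merging=None, verbose=0).fit(X)  # aggregation labels only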
def explain(self, index1=None, index2=None, cmap='jet', showalldata=False, showallgroups=False, showsplist=False, max_colwidth=None, replace_name=None,
plot=False, figsize=(10, 7), figstyle="default", savefig=False, bcolor="#f5f9f9", obj_color="k", width=1.5, obj_msize=160, sp1_color='lime', sp2_color='cyan',
sp_fcolor="tomato", sp_marker="+", sp_size=72, sp_mcolor="k", sp_alpha=0.05, sp_pad=0.5, sp_fontsize=10, sp_bbox=None, sp_cmarker="+", sp_csize=110,
sp_ccolor="crimson", sp_clinewidths=2.7, dp_fcolor="white", dp_alpha=0.5, dp_pad=2, dp_fontsize=10, dp_bbox=None, show_all_grp_circle=False,
show_connected_grp_circle=False, show_obj_grp_circle=True, color="red", connect_color="green", alpha=0.3, cline_width=2, add_arrow=True,
arrow_linestyle="--", arrow_fc="darkslategrey", arrow_ec="k", arrow_linewidth=1, arrow_shrinkA=2, arrow_shrinkB=2, directed_arrow=0,
axis=True, include_dist=False, show_connected_label=True, figname=None, fmt="pdf"):
"""
self.explain(object1/index1) # prints an explanation for why data point object1 is in its cluster (or is an outlier)
self.explain(object1/index1, object2/index2) # prints an explanation of why object1 and object2 are in the same or in distinct clusters
Here we unify the terminology:
[-] data points
[-] groups (made up of data points, formed by aggregation)
[-] clusters (made up of groups)
Parameters
----------
index1 : int, str or numpy.ndarray, optional
Input object1 (with index 'index1') for explanation.
index2 : int, str or numpy.ndarray, optional
Input object2 (with index 'index2') for explanation; when given, the objects with
indices 'index1' and 'index2' are compared.
cmap : str, default='jet'
Colormap for the scatter plot.
showalldata : boolean, default=False
Whether to show all data points in the global view when there are too many points to plot.
showallgroups : boolean, default=False
Whether to show the start point markers.
showsplist : boolean, default=False
Whether to show the group centers information, which includes the number of data points (NumPts),
the corresponding clusters, and the associated coordinates. This only applies when both index1
and index2 are None.
max_colwidth : int, optional
Max width to truncate each column in characters. By default, no limit.
replace_name : str or list, optional
Replace the index with name.
* For example: as for indices 1 and 1300 we have
``classix.explain(1, 1300, plot=False, figstyle="seaborn") # or classix.explain(obj1, obj4)``
The data point 1 is in group 9 and the data point 1300 is in group 8, both of which were merged into cluster #0.
The two groups are connected via groups 9 -> 2 -> 8.
* if we specify the replace name, then the output will be
``classix.explain(1, 1300, replace_name=["Peter Meyer", "Anna Fields"], figstyle="seaborn")``
The data point Peter Meyer is in group 9 and the data point Anna Fields is in group 8, both of which were merged into cluster #0.
The two groups are connected via groups 9 -> 2 -> 8.
plot : boolean, default=False
Whether to visualize the explanation.
figsize : tuple, default=(10, 7)
Size of the explanation figure.
figstyle : str, default="default"
Style of the visualization.
See reference: https://matplotlib.org/stable/gallery/style_sheets/style_sheets_reference.html
savefig : boolean, default=False
Whether to save the figure; it will be saved in a folder named "img".
bcolor : str, default="#f5f9f9"
Color for figure background.
obj_color : str, default="k"
Color for the text labels of the data points index1 and index2.
obj_msize : float, default=160
Marker size for the data points index1 and index2.
sp_fcolor : str, default='tomato'
Face color of the text boxes for group centers.
sp_marker : str, default="+"
The marker for the start points.
sp_size : int, default=72
The marker size for the start points.
sp_mcolor : str, default='k'
The color of the start point scatter markers.
sp_alpha : float, default=0.05
Transparency of the text boxes for group centers.
sp_pad : float, default=0.5
Padding of the text boxes for group centers.
sp_bbox : dict, optional
Dict with properties for patches.FancyBboxPatch for group centers.
sp_fontsize : int, optional
The fontsize for text marked for group centers.
sp_cmarker : str, default="+"
The marker for the connected group centers.
sp_csize : int, default=110
The marker size for the connected group centers.
sp_ccolor : str, default="crimson"
The marker color for the connected group centers.
sp_clinewidths : float, default=2.7
The marker line width for the connected group centers.
dp_fcolor : str, default='white'
Face color of the text boxes for the specified data objects.
dp_alpha : float, default=0.5
Transparency of the text boxes for the specified data objects.
dp_pad : int, default=2
Padding of the text boxes for the specified data objects.
dp_fontsize : int, optional
Font size of the text for the specified data objects.
dp_bbox : dict, optional
Dict with properties for patches.FancyBboxPatch for specified data objects.
show_all_grp_circle : bool, default=False
Whether to show the periphery circles of all groups within the objects' clusters
(only applies when the data dimension is at most 2).
show_connected_grp_circle : bool, default=False
Whether to show the periphery circles of all connected groups within the objects' clusters
(only applies when the data dimension is at most 2).
show_obj_grp_circle : bool, default=True
Whether to show the periphery circles of the objects' groups
(only applies when the data dimension is at most 2).
color : str, default='red'
Color of the group circles drawn in the visualization.
alpha : float, default=0.3
Transparency of data points. Scalar or None.
cline_width : float, default=2
Line width of the circle patches drawn for group centers.
add_arrow : bool, default=True
Whether to add arrows along connected paths.
arrow_linestyle : str, default='--'
Linestyle for arrow.
arrow_fc : str, default='darkslategrey'
Face color for arrow.
arrow_ec : str, default='k'
Edge color for arrow.
arrow_linewidth : float, default=1
Set the linewidth of the arrow edges.
directed_arrow : int, default=0
Whether the arrow edges are directed.
Values in {-1, 0, 1}: 0 means undirected, and -1 means the direction opposite to 1.
arrow_shrinkA, arrow_shrinkB : float, default=2
Shrinking factor for the tail and head of the arrow, respectively.
axis : boolean, default=True
Whether to add x,y axes to the plot.
include_dist : boolean, default=False
Whether or not to include distance information to compute the shortest path between objects.
show_connected_label : boolean, default=True
Whether or not to show the named labels of the connected data points, where the named labels are given by pandas dataframe index.
figname : str, optional
Set the figure name for the image to be saved.
fmt : str
Format of the saved image; default 'pdf', other choices include 'png'.
"""
from scipy.sparse.linalg import svds
self.t5_finalize = time()
# -----------------------------second method--------------------------------
if sp_bbox is None:
sp_bbox = dict()
sp_bbox['facecolor'] = sp_fcolor
sp_bbox['alpha'] = sp_alpha
sp_bbox['pad'] = sp_pad
if dp_bbox is None:
dp_bbox = dict()
dp_bbox['facecolor'] = dp_fcolor
dp_bbox['alpha'] = dp_alpha
dp_bbox['pad'] = dp_pad
if hasattr(self, '__fit__'):
groups_ = np.array(self.groups_)
groups_ = groups_[self.inverse_ind]
if not hasattr(self, 'label_change'):
self.label_change = dict(zip(groups_, self.labels_)) # how object change group to cluster.
else:
raise NotFittedError("Please use .fit() method first.")
data = self.data[self.inverse_ind]
data_size = data.shape[0]
feat_dim = data.shape[1]
if not hasattr(self, 'sp_to_c_info'): # ensure PCA is called and the groups information table is formed only once
if feat_dim > 2:
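# project the data onto its two leading principal components via truncated SVD (used for plotting only)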
_U, self._s, self._V = svds(data, k=2, return_singular_vectors=True)
self.x_pca = np.matmul(data, self._V[(-self._s).argsort()].T)
self.s_pca = self.x_pca[self.ind[self.splist_[:, 0]]]
elif feat_dim == 2:
self.x_pca = data.copy()
self.s_pca = self.data[self.splist_[:, 0]]
else: # when data is one-dimensional, no PCA transform
self.x_pca = np.ones((len(data.copy()), 2))
self.x_pca[:, 0] = data[:, 0]
self.s_pca = np.ones((len(self.splist_), 2))
self.s_pca[:, 0] = self.data[self.splist_[:, 0]].reshape(-1)
self.form_starting_point_clusters_table()
if index1 is None and index2 is not None:
raise ValueError("Please enter a valid value for index1.")
# pd.options.display.max_colwidth = colwidth
dash_line = "--------"*5
if index1 is None: # analyze in the general way with a global view
if plot:
self.explain_viz(showalldata=showalldata, alpha=alpha, cmap=cmap, figsize=figsize, showallgroups=showallgroups, figstyle=figstyle, bcolor=bcolor, savefig=savefig,
fontsize=sp_fontsize, bbox=sp_bbox, sp_marker=sp_marker, sp_mcolor=sp_mcolor, width=width, axis=axis, fmt=fmt)
print("CLASSIX clustered {length:.0f} data points with {dim:.0f} features.\n".format(length=data_size, dim=feat_dim) +
"The radius parameter was set to {tol:.2f} and minPts was set to {minPts:.0f}.\n".format(tol=self.radius, minPts=self.minPts) +
"As the provided data was auto-scaled by a factor of 1/{scl:.2f},\npoints within a radius R={tol:.2f}*{scl:.2f}={tolscl:.2f} were grouped together.\n".format(scl=self.dataScale_, tol=self.radius, tolscl=self.dataScale_*self.radius) +
"In total, {dist:.0f} distances were computed ({avg:.1f} per data point).\n".format(dist=self.nrDistComp_, avg=self.nrDistComp_/data_size) +
"This resulted in {groups:.0f} groups, each with a unique group center.\n".format(groups=self.splist_.shape[0]) +
"These {groups:.0f} groups were subsequently merged into {num_clusters:.0f} clusters. ".format(groups=self.splist_.shape[0], num_clusters=len(np.unique(self.labels_)))
)
if showsplist:
print("A list of all group centers is shown below.")
print(dash_line)
print(self.sp_info.to_string(justify='center', index=False, max_colwidth=max_colwidth))
print(dash_line)
else:
if plot:
print("\nTo explain the clustering of individual data points, use\n " +
".explain(index1) or .explain(index1,index2) with data indices.")
else:
print("\nFor a visualisation of the clusters, use .explain(plot=True).\n" +
"To explain the clustering of individual data points, use\n" +
".explain(index1) or .explain(index1,index2) with data indices.")
else: # index is not None, explain(index1)
if isinstance(index1, numbers.Integral) or isinstance(index1, float):
index1_id, index1 = int(index1), int(index1)
object1 = self.x_pca[index1_id] # data has been normalized
agg_label1 = groups_[index1_id] # get the group index for object1
elif isinstance(index1, str):
if hasattr(self, '_index_data'):
if index1 in self._index_data:
index1_id = np.where(self._index_data == index1)[0][0]
if len(set(self._index_data)) != len(self._index_data):
warnings.warn("The data index contains duplicates.") # SG: Can this even happen with dataframes?
object1 = self.x_pca[index1_id]
agg_label1 = groups_[index1_id]
else:
object1 = self.x_pca[index1_id]
agg_label1 = groups_[index1_id]
else:
raise ValueError("Please use a valid value for index1.")
else:
raise ValueError("Please use a valid value for index1.")
elif isinstance(index1, list) or isinstance(index1, np.ndarray):
index1_id = -1
index1 = np.array(index1)
object1 = (index1 - self.mu_) / self.dataScale_ # allow for out-sample data
if feat_dim > 2:
object1 = np.matmul(object1, self._V[(-self._s).argsort()].T)
agg_label1 = np.argmin(np.linalg.norm(self.s_pca - object1, axis=1, ord=2)) # get the group index for object1
else:
raise ValueError("Please use a valid value for index1.")
# explain one object
if index2 is None:
if replace_name is not None:
if isinstance(replace_name, list):
index1 = replace_name[0]
else:
index1 = replace_name
else:
index1 = index1
cluster_label1 = self.label_change[agg_label1]
if plot:
from matplotlib import pyplot as plt
if self.x_pca.shape[0] > 1e5 and not showalldata:
print("Too many data points for plot. Randomly subsampled 1e5 points.")
selectInd = np.random.choice(self.x_pca.shape[0], 100000, replace=False)
else:
selectInd = np.arange(self.x_pca.shape[0])
if feat_dim > 2:
print("With data having more than two features, the group circles in\nthe plot may appear bigger than they are.")
plt.style.use(style=figstyle)
fig, ax = plt.subplots(figsize=figsize)
ax.set_facecolor(bcolor)
s_pca = self.s_pca[self.sp_info.Cluster == cluster_label1]
ax.scatter(self.x_pca[selectInd, 0], self.x_pca[selectInd, 1], s=60, marker=".", linewidth=0.0*width,
cmap=cmap, alpha=alpha, c=self.labels_[selectInd]
)
ax.scatter(s_pca[:, 0], s_pca[:, 1], marker=sp_marker, label='group centers in cluster #{0}'.format(cluster_label1),
s=sp_size, linewidth=0.9*width, c=sp_mcolor, alpha=0.4)
if feat_dim <= 2 and show_obj_grp_circle:
ax.add_patch(plt.Circle((self.s_pca[agg_label1, 0], self.s_pca[agg_label1, 1]), self.radius, fill=False,
color=sp1_color, alpha=0.5, lw=cline_width*1.5, clip_on=False))
if dp_fontsize is None:
ax.text(object1[0], object1[1], s=' ' + str(index1), bbox=dp_bbox, color=obj_color, zorder=1, ha='left', va='bottom')
else:
ax.text(object1[0], object1[1], s=' ' + str(index1), fontsize=dp_fontsize, bbox=dp_bbox, color=obj_color, zorder=1, ha='left', va='bottom')
if isinstance(index1, str):
ax.scatter(object1[0], object1[1], marker="*", s=obj_msize, label='{} '.format(index1))
else:
ax.scatter(object1[0], object1[1], marker="*", s=obj_msize, label='data point {} '.format(index1))
for i in range(s_pca.shape[0]):
if feat_dim <= 2 and show_all_grp_circle:
ax.add_patch(plt.Circle((s_pca[i, 0], s_pca[i, 1]), self.radius, fill=False, color=color,
alpha=0.5, lw=cline_width*1.5, clip_on=False))
if showallgroups:
if sp_fontsize is None:
ax.text(s_pca[i, 0], s_pca[i, 1],
s=str(self.sp_info.Group[self.sp_info.Cluster == cluster_label1].astype(int).values[i]),
bbox=sp_bbox, zorder=1, ha='left'
)
else:
ax.text(s_pca[i, 0], s_pca[i, 1],
s=str(self.sp_info.Group[self.sp_info.Cluster == cluster_label1].astype(int).values[i]),
fontsize=sp_fontsize, bbox=sp_bbox, zorder=1, ha='left'
)
ax.scatter(self.s_pca[agg_label1, 0], self.s_pca[agg_label1, 1],
marker='.', s=sp_csize*0.3, c=sp1_color, linewidths=sp_clinewidths,
label='group center {0}'.format(agg_label1)
)
ax.set_aspect('equal', adjustable='datalim')
ax.plot()
ax.legend(ncols=3, loc='best') # bbox_to_anchor=(0.5, -0.2)
if axis:
ax.axis('on')
if feat_dim > 1:
ax.set_xlabel("1st principal component")
ax.set_ylabel("2nd principal component")
else:
ax.set_xlabel("1st principal component")
else:
ax.axis('off') # the axis here may not be consistent, so hide.
ax.set_title("""{num_clusters:.0f} clusters (radius={tol:.2f}, minPts={minPts:.0f})""".format(
num_clusters=len(np.unique(self.labels_)),tol=self.radius, minPts=self.minPts))
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
if savefig:
if not os.path.exists("img"):
os.mkdir("img")
if fmt == 'pdf':
if figname is not None:
fm = 'img/' + str(figname) + '.pdf'
else:
fm = 'img/sample.pdf'
plt.savefig(fm, bbox_inches='tight')
elif fmt == 'png':
if figname is not None:
fm = 'img/' + str(figname) + '.png'
else:
fm = 'img/sample.png'
plt.savefig(fm, bbox_inches='tight')
else:
if figname is not None:
fm = 'img/' + str(figname) + '.' + fmt
else:
fm = 'img/sample' + '.' + fmt
plt.savefig(fm, bbox_inches='tight')
print("Image successfully saved as", fm)
plt.show()
if showsplist:
select_sp_info = self.sp_info.iloc[[agg_label1]].copy(deep=True)
select_sp_info.loc[:, 'Label'] = str(np.round(index1,3))
print(dash_line)
print(select_sp_info.to_string(justify='center', index=False, max_colwidth=max_colwidth))
print(dash_line)
print(
"""Data point %(index1)s is in group %(agg_id)i, which was merged into cluster #%(m_c)i."""% {
"index1":index1, "agg_id":agg_label1, "m_c":cluster_label1
}
)
if not plot:
print("Use .explain(..., plot=True) for a visual representation.")
# explain two objects relationship
else:
if isinstance(index2, numbers.Integral) or isinstance(index2, float):
index2_id, index2 = int(index2), int(index2)
object2 = self.x_pca[index2_id] # data has been normalized
agg_label2 = groups_[index2_id] # get the group index for object2
elif isinstance(index2, str):
if hasattr(self, '_index_data'):
if index2 in self._index_data:
index2_id = np.where(self._index_data == index2)[0][0]
if len(set(self._index_data)) != len(self._index_data):
warnings.warn("The data index contains duplicates.") # sg: can this even happen with dataframes?
object2 = self.x_pca[index2_id]
agg_label2 = groups_[index2_id]
else:
object2 = self.x_pca[index2_id]
agg_label2 = groups_[index2_id]
else:
raise ValueError("Please use a valid value for index2.")
else:
raise ValueError("Please use a valid value for index2.")
elif isinstance(index2, list) or isinstance(index2, np.ndarray):
index2_id = -1
index2 = np.array(index2)
object2 = (index2 - self.mu_) / self.dataScale_ # allow for out-sample data
if feat_dim > 2:
object2 = np.matmul(object2, self._V[(-self._s).argsort()].T)
agg_label2 = np.argmin(np.linalg.norm(self.s_pca - object2, axis=1, ord=2)) # get the group index for object2
else:
raise ValueError("Please use a valid value for index2.")
if showsplist:
select_sp_info = self.sp_info.iloc[[agg_label1, agg_label2]].copy(deep=True)
if isinstance(index1, int) or isinstance(index1, str):
select_sp_info.loc[:, 'Label'] = [index1, index2]
else:
select_sp_info.loc[:, 'Label'] = [str(np.round(index1, 3)), str(np.round(index2, 3))]
print(dash_line)
print(select_sp_info.to_string(justify='center', index=False, max_colwidth=max_colwidth))
print(dash_line)
if replace_name is not None:
if isinstance(replace_name, list) or isinstance(replace_name, np.ndarray):
try:
index1 = replace_name[0]
index2 = replace_name[1]
except IndexError:
index1 = replace_name[0]
else:
index1 = index1
index2 = index2
cluster_label1, cluster_label2 = self.label_change[agg_label1], self.label_change[agg_label2]
if agg_label1 == agg_label2: # when ind1 & ind2 are in the same group
connected_paths = [agg_label1]
else:
from scipy.sparse import csr_matrix
distm = pairwise_distances(self.data[self.splist_[:, 0]])
distmf = (distm <= self.radius*self.mergeScale_).astype(int)
csr_dist_m = csr_matrix(distmf)
if cluster_label1 == cluster_label2:
connected_paths = find_shortest_dist_path(agg_label1, csr_dist_m, agg_label2, unweighted=not include_dist)
connected_paths.reverse()
if len(connected_paths)<1:
connected_paths_vis = None
else:
connected_paths_vis = " <-> ".join([str(group) for group in connected_paths])
else:
connected_paths = []
if plot:
from matplotlib import pyplot as plt
if self.x_pca.shape[0] > 1e5 and not showalldata:
print("Too many data points for plot. Randomly subsampled 1e5 points.")
selectInd = np.random.choice(self.x_pca.shape[0], 100000, replace=False)
else:
selectInd = np.arange(self.x_pca.shape[0])
if feat_dim > 2:
print("With data having more than two features, the group circles in\nthe plot may appear bigger than they are.")
plt.style.use(style=figstyle)
fig, ax = plt.subplots(figsize=figsize)
ax.set_facecolor(bcolor)
# select indices
union_ind = np.where((self.sp_info.Cluster == cluster_label1) | (self.sp_info.Cluster == cluster_label2))[0]
s_pca = self.s_pca[union_ind]
ax.scatter(self.x_pca[selectInd, 0], self.x_pca[selectInd, 1], s=60, marker=".", c=self.labels_[selectInd], linewidth=0*width, cmap=cmap, alpha=alpha)
ax.scatter(s_pca[:,0], s_pca[:,1], label='group centers', marker=sp_marker, s=sp_size, c=sp_mcolor, linewidth=0.9*width, alpha=0.4)
if feat_dim <= 2 and show_obj_grp_circle:
ax.add_patch(plt.Circle((self.s_pca[agg_label1, 0], self.s_pca[agg_label1, 1]), self.radius, fill=False,
color=sp1_color, alpha=0.5, lw=cline_width*1.5, clip_on=False))
ax.add_patch(plt.Circle((self.s_pca[agg_label2, 0], self.s_pca[agg_label2, 1]), self.radius, fill=False,
color=sp2_color, alpha=0.5, lw=cline_width*1.5, clip_on=False))
if isinstance(index1, int) or isinstance(index1, str):
if dp_fontsize is None:
ax.text(object1[0], object1[1], s=' '+str(index1), ha='left', va='bottom', zorder=1, bbox=dp_bbox, color=obj_color)
ax.text(object2[0], object2[1], s=' '+str(index2), ha='left', va='bottom', zorder=1, bbox=dp_bbox, color=obj_color)
else:
ax.text(object1[0], object1[1], s=' '+str(index1), ha='left', va='bottom', zorder=1, fontsize=dp_fontsize, bbox=dp_bbox, color=obj_color)
ax.text(object2[0], object2[1], s=' '+str(index2), ha='left', va='bottom', zorder=1, fontsize=dp_fontsize, bbox=dp_bbox, color=obj_color)
else:
if dp_fontsize is None:
ax.text(object1[0], object1[1], s=' '+'index 1', ha='left', va='bottom', zorder=1, bbox=dp_bbox, color=obj_color)
ax.text(object2[0], object2[1], s=' '+'index 2', ha='left', va='bottom', zorder=1, bbox=dp_bbox, color=obj_color)
else:
ax.text(object1[0], object1[1], s=' '+'index 1', ha='left', va='bottom', zorder=1, fontsize=dp_fontsize, bbox=dp_bbox, color=obj_color)
ax.text(object2[0], object2[1], s=' '+'index 2', ha='left', va='bottom', zorder=1, fontsize=dp_fontsize, bbox=dp_bbox, color=obj_color)
if isinstance(index1, str):
ax.scatter(object1[0], object1[1], marker="*", s=obj_msize,
label='{} '.format(index1)+'(cluster #{0})'.format(
cluster_label1)
)
else:
ax.scatter(object1[0], object1[1], marker="*", s=obj_msize,
label='data point {} '.format(index1)+'(cluster #{0})'.format(
cluster_label1)
)
if isinstance(index2, str):
ax.scatter(object2[0], object2[1], marker="*", s=obj_msize,
label='{} '.format(index2)+'(cluster #{0})'.format(
cluster_label2)
)
else:
ax.scatter(object2[0], object2[1], marker="*", s=obj_msize,
label='data point {} '.format(index2)+'(cluster #{0})'.format(
cluster_label2)
)
for i in range(s_pca.shape[0]):
if feat_dim <= 2 and show_all_grp_circle:
ax.add_patch(plt.Circle((s_pca[i, 0], s_pca[i, 1]), self.radius, fill=False,
color=color, alpha=0.5, lw=cline_width*1.5, clip_on=False)
)
if showallgroups:
if sp_fontsize is None:
ax.text(s_pca[i, 0], s_pca[i, 1],
s=self.sp_info.Group[
(self.sp_info.Cluster == cluster_label1) | (self.sp_info.Cluster == cluster_label2)
].values[i].astype(int).astype(str),
zorder=1, ha='left', bbox=sp_bbox
)
else:
ax.text(s_pca[i, 0], s_pca[i, 1],
s=self.sp_info.Group[union_ind].values[i].astype(int).astype(str),
fontsize=sp_fontsize, ha='left', bbox=sp_bbox
)
for i in connected_paths:
# draw circle for connected group centers or not,
# and also determine the marker of the connected group centers.
if i == connected_paths[0]:
ax.scatter(self.s_pca[i,0], self.s_pca[i,1], marker=sp_cmarker, s=sp_csize,
label='connected groups', c=sp_ccolor, linewidths=sp_clinewidths)
else:
ax.scatter(self.s_pca[i,0], self.s_pca[i,1], marker=sp_cmarker, s=sp_csize, c=sp_ccolor,
linewidths=sp_clinewidths)
if feat_dim <= 2 and show_connected_grp_circle:
ax.add_patch(plt.Circle((self.s_pca[i, 0], self.s_pca[i, 1]), self.radius, fill=False,
color=connect_color, alpha=0.5, lw=cline_width*1.5, clip_on=False))
ax.scatter(self.s_pca[agg_label1, 0], self.s_pca[agg_label1, 1],
marker='.', s=sp_csize*0.3, c=sp1_color, linewidths=sp_clinewidths,
label='group center {0}'.format(agg_label1)
)
ax.scatter(self.s_pca[agg_label2, 0], self.s_pca[agg_label2, 1],
marker='.', s=sp_csize*0.3, c=sp2_color, linewidths=sp_clinewidths,
label='group center {0}'.format(agg_label2)
)
nr_cps = len(connected_paths)
if add_arrow:
for i in range(nr_cps-1):
arrowStart=(self.s_pca[connected_paths[i], 0], self.s_pca[connected_paths[i], 1])
arrowStop=(self.s_pca[connected_paths[i+1], 0], self.s_pca[connected_paths[i+1], 1])
if directed_arrow == 0:
ax.annotate("", arrowStop,
xytext=arrowStart,
arrowprops=dict(arrowstyle="-|>",
shrinkA=arrow_shrinkA,
shrinkB=arrow_shrinkB,
edgecolor=arrow_ec,
facecolor=arrow_fc,
linestyle=arrow_linestyle,
linewidth=arrow_linewidth
)
)
ax.annotate("", arrowStart,
xytext=arrowStop,
arrowprops=dict(arrowstyle="-|>",
shrinkA=arrow_shrinkA,
shrinkB=arrow_shrinkB,
edgecolor=arrow_ec,
facecolor=arrow_fc,
linestyle=arrow_linestyle,
linewidth=arrow_linewidth
)
)
elif directed_arrow == 1:
ax.annotate("", arrowStop,
xytext=arrowStart,
arrowprops=dict(arrowstyle="-|>",
shrinkA=arrow_shrinkA,
shrinkB=arrow_shrinkB,
edgecolor=arrow_ec,
facecolor=arrow_fc,
linestyle=arrow_linestyle,
linewidth=arrow_linewidth
)
)
else:
ax.annotate("", arrowStart,
xytext=arrowStop,
arrowprops=dict(arrowstyle="-|>",
shrinkA=arrow_shrinkA,
shrinkB=arrow_shrinkB,
edgecolor=arrow_ec,
facecolor=arrow_fc,
linestyle=arrow_linestyle,
linewidth=arrow_linewidth
)
)
if cluster_label1 == cluster_label2 and len(connected_paths) > 1: # change the order of legend
handles, lg_labels = ax.get_legend_handles_labels()
lg_labels = [lg_labels[i] for i in [0,3,1,2,4,5]]
handles = [handles[i] for i in [0,3,1,2,4,5]]
ax.legend(handles, lg_labels, ncols=3, loc='best')
else:
ax.legend(ncols=3, loc='best')
ax.set_aspect('equal', adjustable='datalim')
ax.set_title("""{num_clusters:.0f} clusters (radius={tol:.2f}, minPts={minPts:.0f})""".format(
num_clusters=len(np.unique(self.labels_)),tol=self.radius, minPts=self.minPts))
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
if axis:
ax.axis('on')
if feat_dim > 1:
ax.set_xlabel("1st principal component")
ax.set_ylabel("2nd principal component")
else:
ax.set_xlabel("1st principal component")
else:
ax.axis('off') # the axis here may not be consistent, so hide.
ax.plot()
if savefig:
if not os.path.exists("img"):
os.mkdir("img")
if fmt == 'pdf':
if figname is not None:
fm = 'img/' + str(figname) + '.pdf'
else:
fm = 'img/sample.pdf'
plt.savefig(fm, bbox_inches='tight')
elif fmt == 'png':
if figname is not None:
fm = 'img/' + str(figname) + '.png'
else:
fm = 'img/sample.png'
plt.savefig(fm, bbox_inches='tight')
else:
if figname is not None:
fm = 'img/' + str(figname) + '.' + fmt
else:
fm = 'img/sample' + '.' + fmt
plt.savefig(fm, bbox_inches='tight')
print("image successfully save as", fm)
plt.show()
if agg_label1 == agg_label2: # when ind1 & ind2 are in the same group
print("The data points %(index1)s and %(index2)s are in the same group %(agg_id)i, hence were merged into the same cluster #%(m_c)i"%{
"index1":index1, "index2":index2, "agg_id":agg_label1, "m_c":cluster_label1}
)
else:
if cluster_label1 == cluster_label2:
print(
"""Data point %(index1)s is in group %(agg_id1)s.\nData point %(index2)s is in group %(agg_id2)s.\n"""
"""Both groups were merged into cluster #%(cluster)i. """% {
"index1":index1, "index2":index2, "cluster":cluster_label1, "agg_id1":agg_label1, "agg_id2":agg_label2}
)
if connected_paths_vis is None:
print('No path from group {0} to group {1} with step size <=1.5*R={2:3.2f}.'.format(agg_label1, agg_label2, self.radius*self.mergeScale_))
print('This is because at least one of the groups was reassigned due to the minPts condition.')
else:
print("""\nThe two groups are connected via groups\n %(connected)s.""" % {
"connected":connected_paths_vis}
)
if hasattr(self, '_index_data') and show_connected_label:
show_connected_df = pd.DataFrame(columns=["Index", "Distance", "Group", "Label"])
show_connected_df["Index"] = np.insert(self.gcIndices(connected_paths), [0, len(connected_paths)], [index1_id, index2_id])
consecutive_distances = [distance.euclidean(data[index1_id], data[show_connected_df["Index"].iloc[1]])] + [distm[connected_paths[i],
connected_paths[i+1]] for i in range(len(connected_paths)-1)] + [distance.euclidean(
data[show_connected_df["Index"].iloc[-2]], data[index2_id])]
consecutive_distances = ["{0:5.2f}".format(i*self.dataScale_) for i in consecutive_distances]
show_connected_df.loc[1:, "Distance"] = consecutive_distances
show_connected_df.loc[0, "Distance"] = '--'
show_connected_df["Group"] = [agg_label1] + connected_paths + [agg_label2]
if isinstance(index1, int):
table_index1 = self._index_data[index1]
else:
table_index1 = index1
if isinstance(index2, int):
table_index2 = self._index_data[index2]
else:
table_index2 = index2
show_connected_df["Label"] = [table_index1] + self._index_data[self.gcIndices(connected_paths).astype(int)].tolist() + [table_index2]
else:
show_connected_df = pd.DataFrame(columns=["Index", "Distance", "Group"])
show_connected_df["Index"] = np.insert(self.gcIndices(connected_paths), [0, len(connected_paths)], [index1_id, index2_id])
consecutive_distances = [distance.euclidean(data[index1_id], data[show_connected_df["Index"].iloc[1]])] + [distm[connected_paths[i],
connected_paths[i+1]] for i in range(len(connected_paths)-1)] + [distance.euclidean(
data[show_connected_df["Index"].iloc[-2]], data[index2_id])]
consecutive_distances = ["{0:5.2f}".format(i*self.dataScale_) for i in consecutive_distances]
show_connected_df.loc[1:, "Distance"] = consecutive_distances
show_connected_df.loc[0, "Distance"] = '--'
show_connected_df["Group"] = [agg_label1] + connected_paths + [agg_label2]
print('\nHere is a list of connected data points with\ntheir global data indices and group numbers:\n\n', show_connected_df.to_string(index=False), '\n')
print("""The distance between consecutive data points is at most R={0:0.3n}. """.format(self.radius*self.dataScale_*self.mergeScale_, width=0))
print("""Here, R={0:0.3n}*{1:0.3n}*{2:0.3n}, where {3:0.3n} is the chosen radius parameter, """.format(self.radius, self.dataScale_, self.mergeScale_, self.radius, align='<', width=0))
print("""dataScale_={0:0.3n} is a data scaling factor determined by CLASSIX, """.format(self.dataScale_, width=0))
if self.mergeScale_ == 1.5:
print("""and mergeScale_={0:0.3n} (the default value).""".format(self.mergeScale_))
else:
print("""and mergeScale_={0:0.3n}.""".format(self.mergeScale_))
if not plot:
print("Use .explain(..., plot=True) for a visual representation.")
else:
connected_paths = []
print("""Data point %(index1)s is in group %(agg_id1)i, which was merged into cluster %(c_id1)s.""" % {
"index1":index1, "agg_id1":agg_label1, "c_id1":cluster_label1})
print("""Data point %(index2)s is in group %(agg_id2)i, which was merged into cluster %(c_id2)s.""" % {
"index2":index2, "agg_id2":agg_label2, "c_id2":cluster_label2})
print("""There is no path of overlapping groups between these clusters.""")
self.connected_paths = connected_paths
self.t5_finalize = time() - self.t5_finalize
return
def explain_viz(self, showalldata=False, alpha=0.5, cmap='Set3', figsize=(10, 7), showallgroups=False, figstyle="default", bcolor="white", width=0.5, sp_marker="+", sp_mcolor="k",
savefig=False, fontsize=None, bbox=None, axis=True, fmt="pdf"):
"""Visualize the starting point and data points"""
from matplotlib import pyplot as plt
if self.x_pca.shape[0] > 1e5 and not showalldata:
print("Too many data points for plot. Randomly subsampled 1e5 points.")
selectInd = np.random.choice(self.x_pca.shape[0], 100000, replace=False)
else:
selectInd = np.arange(self.x_pca.shape[0])
plt.style.use(style=figstyle)
plt.figure(figsize=figsize)
plt.rcParams['axes.facecolor'] = bcolor
plt.scatter(self.x_pca[selectInd,0], self.x_pca[selectInd,1], s=60, marker=".", linewidth=0*width, c=self.labels_[selectInd], cmap=cmap, alpha=alpha)
if showallgroups:
for j in range(self.s_pca.shape[0]):
if fontsize is None:
plt.text(self.s_pca[j, 0], self.s_pca[j, 1], str(j), zorder=1, ha='left', bbox=bbox)
else:
plt.text(self.s_pca[j, 0], self.s_pca[j, 1], str(j), zorder=1, ha='left', fontsize=fontsize, bbox=bbox)
if showallgroups:
plt.scatter(self.s_pca[:,0], self.s_pca[:,1], label='group centers',
marker=sp_marker, linewidth=0.9*width, c=sp_mcolor)
plt.axis('equal')
plt.title("""{num_clusters:.0f} clusters (radius={tol:.2f}, minPts={minPts:.0f})""".format(
num_clusters=len(np.unique(self.labels_)),tol=self.radius, minPts=self.minPts))
        if axis == 'on' or axis is True:  # note: the default string 'off' would otherwise be truthy
plt.axis('on')
if self.s_pca.shape[1] > 1:
plt.xlabel("1st principal component")
plt.ylabel("2nd principal component")
else:
plt.xlabel("1st principal component")
else:
            plt.axis('off')  # axis scales can be misleading after projection, so hide them
plt.gca().spines['right'].set_color('none')
plt.gca().spines['top'].set_color('none')
if savefig:
if not os.path.exists("img"):
os.mkdir("img")
            fm = 'img/explain_viz.' + fmt
            plt.savefig(fm, bbox_inches='tight')
            print("Image successfully saved as", fm)
plt.show()
return
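    # Illustrative usage sketch, assuming a fitted model ``clx``: this scatters
    # the (PCA-projected) data colored by cluster label and, optionally,
    # overlays and annotates the group centers.
    #
    #   >>> clx.explain_viz(showallgroups=True, fontsize=8)
    #   >>> clx.explain_viz(savefig=True, fmt='png')   # also writes img/explain_viz.png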
def timing(self):
"""
This method will print five timing information regarding classix clustering:
(1) t1_prepare: The initial data preparation, which mainly comprises data scaling and the computation of the first two principal axes.
(2) t2_aggregate: This phase aggregates all data points into groups determined by the radius parameter of CLASSIX.
(3) t3_merge: The computed groups will be merged into clusters when their group centers (starting points) are sufficiently close.
(4) t4_minPts: Clusters with fewer than minPts points will be dissolved into their groups, and each of the groups will then be reassigned to a large enough cluster.
(5) t5_finalize: Any cleanup activities.
"""
if hasattr(self, '__fit__'):
print("t1_prepare:", self.t1_prepare)
print("t2_aggregate:", self.t2_aggregate)
print("t3_merge:", self.t3_merge)
print("t3_merge time:", self.t3_merge)
if hasattr(self, 't5_finalize'):
print("t5_finalize time:", self.t5_finalize)
else:
raise NotFittedError("Please use .fit() method first.")
def getPath(self, index1, index2, include_dist=False):
"""
Get the indices of connected data points between index1 data and index2 data.
Parameters
----------
index1 : int
Index for data point.
index2 : int
Index for data point.
Returns
-------
connected_points : numpy.ndarray
connected data points.
"""
from scipy.sparse import csr_matrix
if hasattr(self, '__fit__'):
groups_ = np.array(self.groups_)
groups_ = groups_[self.inverse_ind]
else:
raise NotFittedError("Please use .fit() method first.")
if index1 == index2:
return np.array([index1, index2])
agg_label1 = groups_[index1]
agg_label2 = groups_[index2]
if not include_dist and hasattr(self, 'connected_pairs_'): # precomputed distance
            num_nodes = self.splist_.shape[0]
            distm = np.zeros((num_nodes, num_nodes), dtype=int)  # adjacency matrix of the group graph
pairs = np.asarray(self.connected_pairs_, dtype=int)
for pair in pairs:
distm[pair[0], pair[1]] = distm[pair[1], pair[0]] = 1
csr_dist_m = csr_matrix(distm)
            connected_paths = find_shortest_dist_path(agg_label1, csr_dist_m, agg_label2, unweighted=not include_dist)
connected_paths.reverse()
else:
            distm = pairwise_distances(self.data[self.splist_[:, 0].astype(int)])
distm = (distm <= self.radius*self.mergeScale_).astype(int)
csr_dist_m = csr_matrix(distm)
connected_paths = find_shortest_dist_path(agg_label1, csr_dist_m, agg_label2, unweighted=not include_dist)
connected_paths.reverse()
if len(connected_paths) >= 1:
connected_points = np.insert(self.gcIndices(connected_paths), [0, len(connected_paths)], [index1, index2])
return connected_points
else:
return np.array([])
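    # Illustrative usage sketch, assuming a fitted model ``clx``: the returned
    # array starts at index1, ends at index2, and lists the global indices of
    # the group centers in between (empty if the points are not connected).
    #
    #   >>> clx.getPath(0, 8)
    #   >>> clx.getPath(0, 8, include_dist=True)   # weight the path search by center distances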
def visualize_linkage(self, scale=1.5, figsize=(10,7), labelsize=24, markersize=320, plot_boundary=False, bound_color='red', path='.', fmt='pdf'):
"""Visualize the linkage in the distance clustering.
Parameters
----------
scale : float
Design for distance-clustering, when distance between the two group centers
associated with two distinct groups smaller than scale*radius, then the two groups merge.
labelsize : int
The fontsize of ticks.
markersize : int
The size of the markers for group centers.
plot_boundary : boolean
If it is true, will plot the boundary of groups for the group centers.
bound_color : str
The color for the boundary for groups with the specified radius.
path : str
Relative file location for figure storage.
fmt : str
Specify the format of the image to be saved, default as 'pdf', other choice: png.
"""
from scipy.sparse import csr_matrix
from matplotlib import pyplot as plt
if not hasattr(self, '__fit__'):
raise NotFittedError("Please use .fit() method first.")
distm, n_components, labels = visualize_connections(self.data, self.splist_, radius=self.radius, scale=round(scale,2))
plt.rcParams['axes.facecolor'] = 'white'
P = self.data[self.splist_[:, 0].astype(int)]
link_list = return_csr_matrix_indices(csr_matrix(distm))
fig, ax = plt.subplots(figsize=figsize)
for i in range(self.splist_.shape[0]):
ax.scatter(P[i,0], P[i,1], s=markersize, c='k', marker='.')
if plot_boundary and self.data.shape[1] <= 2:
ax.add_patch(plt.Circle((P[i, 0], P[i, 1]), self.radius,
color=bound_color, fill=False, clip_on=False)
)
ax.set_aspect('equal', adjustable='datalim')
for edge in link_list:
i, j = edge
ax.plot([P[i, 0], P[j, 0]], [P[i, 1], P[j, 1]], linewidth=3, c='k')
ax.tick_params(axis='both', labelsize=labelsize, colors='k')
if not os.path.isdir(path):
os.makedirs(path)
        ext = 'pdf' if fmt == 'pdf' else 'png'
        fig.savefig(path + '/linkage_mergeScale_' + str(round(scale, 2)) + '_tol_' + str(round(self.radius, 2)) + '.' + ext, bbox_inches='tight')
def preprocessing(self, data):
"""
Normalize the data by the fitted model.
"""
if hasattr(self, '__fit__'):
return (data - self.mu_) / self.dataScale_
else:
raise NotFittedError("Please use .fit() method first.")
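    # Illustrative usage sketch, assuming a fitted model ``clx``: new points can
    # be mapped into the fitted model's normalized coordinates, e.g., before
    # comparing them against the group centers.
    #
    #   >>> Xnew_scaled = clx.preprocessing(Xnew)   # (Xnew - clx.mu_) / clx.dataScale_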
@property
def groupCenters_(self):
if hasattr(self, '__fit__'):
return self._gcIndices(np.arange(self.splist_.shape[0]))
else:
raise NotFittedError("Please use .fit() method first.")
@property
def clusterSizes_(self):
if hasattr(self, '__fit__'):
counter = collections.Counter(self.labels_)
return np.array(list(counter.values()))[np.argsort(list(counter.keys()))]
else:
raise NotFittedError("Please use .fit() method first.")
def gcIndices(self, ids):
return self._gcIndices(ids)
def gc2ind(self, spid):
return self.ind[self.splist_[spid, 0]]
    def load_group_centers(self):
        """Compute (on first call) and return the group centers."""
        if not hasattr(self, '__fit__'):
            raise NotFittedError("Please use .fit() method first.")
        if not hasattr(self, 'grp_centers'):
            self.grp_centers = calculate_cluster_centers(self.data, self.groups_)
        return self.grp_centers
    def load_cluster_centers(self):
        """Compute (on first call) and return the cluster centers."""
        if not hasattr(self, '__fit__'):
            raise NotFittedError("Please use .fit() method first.")
        if not hasattr(self, 'centers'):
            self.centers = calculate_cluster_centers(self.data[self.inverse_ind], self.labels_)
        return self.centers
def outlier_filter(self, min_samples=None, min_samples_rate=0.1): # percent
"""Filter outliers in terms of ``min_samples`` or ``min_samples_rate``. """
        if min_samples is None:
            min_samples = min_samples_rate * sum(self.old_cluster_count.values())
        return [label for label, count in self.old_cluster_count.items() if count < min_samples]
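    # Illustrative usage sketch, assuming a fitted model ``clx``: with the
    # default min_samples_rate=0.1, the labels of all clusters holding fewer
    # than 10% of the points are returned.
    #
    #   >>> small_clusters = clx.outlier_filter()
    #   >>> small_clusters = clx.outlier_filter(min_samples=10)   # absolute threshold instead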
def __repr__(self):
_name = "CLASSIX(radius={0.radius!r}, minPts={0.minPts!r}, group_merging={0.group_merging!r})".format(self)
return _name
def __str__(self):
_name = 'CLASSIX(radius={0.radius!r}, minPts={0.minPts!r}, group_merging={0.group_merging!r})'.format(self)
return _name
@property
def radius(self):
return self._radius
@radius.setter
def radius(self, value):
if not isinstance(value, float) and not isinstance(value,int):
raise TypeError('Expected a float or int type')
if value <= 0:
raise ValueError(
"Please feed an correct value (>0) for tolerance.")
self._radius = value
@property
def sorting(self):
return self._sorting
@sorting.setter
def sorting(self, value):
if not isinstance(value, str) and not isinstance(value, type(None)):
raise TypeError('Expected a string type')
        if value not in ['pca', 'norm-mean', 'norm-orthant'] and value is not None:
            raise ValueError(
                "Please choose a valid sorting method: 'pca', 'norm-mean', or 'norm-orthant'.")
self._sorting = value
@property
def group_merging(self):
return self._group_merging
@group_merging.setter
def group_merging(self, value):
if not isinstance(value, str) and not isinstance(value, type(None)):
raise TypeError('Expected a string type or None.')
        if value not in ['density', 'distance'] and value is not None:
            if value.lower() != 'none':
                raise ValueError(
                    "Please choose a valid merging method: 'density', 'distance', or None.")
self._group_merging = value
@property
def minPts(self):
return self._minPts
@minPts.setter
def minPts(self, value):
        if isinstance(value, bool) or not isinstance(value, numbers.Number):
            raise TypeError('Expected a float or int type.')
        if hasattr(value, "__len__"):
            raise TypeError('Expected a scalar.')
        if value < 0 or 0 < value < 1:  # note: '0 < value & value < 1' would misparse due to operator precedence
            raise ValueError('minPts must be 0 or at least 1.')
self._minPts = int(round(value))
def pairwise_distances(X):
"""Calculate the Euclidean distance matrix."""
return distance.squareform(distance.pdist(X))
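# A minimal check of pairwise_distances (illustrative sketch): distance.pdist
# returns the condensed upper triangle, and distance.squareform expands it into
# the full symmetric matrix with zeros on the diagonal.
#
#   >>> X = np.array([[0., 0.], [3., 4.]])
#   >>> pairwise_distances(X)
#   array([[0., 5.],
#          [5., 0.]])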
def visualize_connections(data, splist, radius=0.5, scale=1.5):
"""Calculate the connected components for graph constructed by group centers given radius and mergeScale."""
from scipy.sparse.csgraph import connected_components
distm = pairwise_distances(data[splist[:,0].astype(int)])
tol = radius*scale
distm = (distm <= tol).astype(int)
n_components, labels = connected_components(csgraph=distm, directed=False, return_labels=True)
return distm, n_components, labels
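# A minimal check (illustrative sketch, with a hypothetical one-column splist
# whose first column holds the center indices): two centers at distance 0.6
# are linked under the defaults since 0.6 <= 0.5*1.5.
#
#   >>> centers = np.array([[0., 0.], [0.6, 0.]])
#   >>> splist = np.array([[0], [1]])
#   >>> distm, n_components, labels = visualize_connections(centers, splist)
#   >>> n_components
#   1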
def preprocessing(data, base):
"""Initial data preparation of CLASSIX."""
if base == "norm-mean":
_mu = data.mean(axis=0)
ndata = data - _mu
dataScale = ndata.std()
ndata = ndata / dataScale
elif base == "pca":
_mu = data.mean(axis=0)
ndata = data - _mu # mean center
rds = norm(ndata, axis=1) # distance of each data point from 0
dataScale = np.median(rds) # 50% of data points are within that radius
ndata = ndata / dataScale # now 50% of data are in unit ball
elif base == "norm-orthant":
_mu = data.min(axis=0)
ndata = data - _mu
dataScale = ndata.std()
ndata = ndata / dataScale
else:
_mu, dataScale = 0, 1 # no preprocessing
ndata = (data - _mu) / dataScale
return ndata, (_mu, dataScale)
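# A minimal check of the 'pca' base (illustrative sketch): after mean centering,
# the data are divided by the median point norm, so half of the points lie
# inside the unit ball.
#
#   >>> X = np.array([[1., 0.], [-1., 0.], [2., 0.], [-2., 0.]])
#   >>> ndata, (mu, scale) = preprocessing(X, "pca")
#   >>> scale          # median of the norms [1, 1, 2, 2]
#   1.5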
def calculate_cluster_centers(data, labels):
    """Calculate the mean centers of the clusters from the given data."""
    classes = np.unique(labels)
    centers = np.zeros((len(classes), data.shape[1]))
    for i, c in enumerate(classes):  # enumerate so that non-contiguous labels are handled safely
        centers[i] = np.mean(data[labels==c, :], axis=0)
    return centers
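# A minimal check (illustrative sketch): the center of each labeled group is
# the componentwise mean of its members.
#
#   >>> X = np.array([[0., 0.], [2., 0.], [10., 10.]])
#   >>> calculate_cluster_centers(X, np.array([0, 0, 1]))
#   array([[ 1.,  0.],
#          [10., 10.]])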
# ##########################################################################################################
# ************** independent functions for finding the shortest path between two objects ******************
# ##########################################################################################################
def find_shortest_dist_path(source_node=None, graph=None, target_node=None, unweighted=True):
""" Get single-sourse shortest paths as well as distance from source node,
design especially for unweighted undirected graph. The time complexity is O(|V| + |E|)
where |V| is the number of vertices and |E| is the number of edges.
Parameters
----------
source_node: int
A given source vertex.
graph : scipy.sparse._csr.csr_matrix
Input as a sparse matrix format.
target_node: int, default=None
Find the shortest paths from source node to target node.
If not None, function returns the shortest path between source node and target node,
otherwise returns table storing shortest path information.
unweighted : bool, default=True
If True, then find unweighted distances, i.e., find the path such that the number of edges is minimized.
Returns
-------
shortest_path_to_target: list
The shortest path between source node and target node
"""
from scipy.sparse.csgraph import shortest_path
dist_matrix, predecessors = shortest_path(csgraph=graph, directed=False, unweighted=unweighted, indices=source_node, return_predecessors=True)
    if predecessors[target_node] != -9999:  # -9999 is scipy's sentinel for "no predecessor"
shortest_path_to_target = []
shortest_path_to_target.append(target_node)
predecessor = predecessors[target_node]
while predecessor != -9999:
shortest_path_to_target.append(predecessor)
predecessor = predecessors[predecessor]
return shortest_path_to_target
else:
return []
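# A minimal check on the path graph 0 - 1 - 2 (illustrative sketch): the result
# runs from the target back to the source, which is why the callers above
# reverse it.
#
#   >>> from scipy.sparse import csr_matrix
#   >>> A = csr_matrix(np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]]))
#   >>> find_shortest_dist_path(0, A, 2)
#   [2, 1, 0]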
def return_csr_matrix_indices(csr_mat):
    """Return the (row, column) indices of the stored entries of a CSR matrix."""
    from scipy.sparse import _sparsetools
    num_rows, _ = csr_mat.shape
    col_indices = csr_mat.indices
    row_indices = np.empty(len(col_indices), dtype=csr_mat.indices.dtype)
    _sparsetools.expandptr(num_rows, csr_mat.indptr, row_indices)  # expand indptr into explicit row indices
    return np.array(list(zip(row_indices, col_indices)))
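# A minimal check (illustrative sketch): the function recovers the (row, column)
# coordinates of the stored entries, equivalent to np.argwhere on the dense matrix.
#
#   >>> from scipy.sparse import csr_matrix
#   >>> M = csr_matrix(np.array([[0, 1], [1, 0]]))
#   >>> return_csr_matrix_indices(M)
#   array([[0, 1],
#          [1, 0]])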
def euclid(xxt, X, v):
    """Compute the squared Euclidean distances from every row of X to the vector v,
    using the expansion ||x - v||^2 = ||x||^2 + ||v||^2 - 2*(x . v),
    where xxt holds the precomputed squared row norms ||x||^2."""
    return (xxt + np.inner(v, v).ravel() - 2*X.dot(v)).astype(float)
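# A minimal check of euclid (illustrative sketch): with the squared row norms
# xxt = (X**2).sum(axis=1) precomputed once, the squared distances from all
# rows of X to a query v cost only one matrix-vector product.
#
#   >>> X = np.array([[0., 0.], [3., 4.]])
#   >>> euclid((X**2).sum(axis=1), X, np.array([0., 0.]))
#   array([ 0., 25.])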