"""Base class for objects of type topoplot."""
import logging
import numpy as np
from scipy.interpolate import interp2d
from vispy import scene
from vispy.scene import visuals
import vispy.visuals.transforms as vist
from .visbrain_obj import VisbrainObject
from ..objects import ConnectObj
from ..io import download_file, is_sc_image_installed
from ..utils import (array2colormap, color2vb, mpl_cmap, normalize,
vpnormalize, vprecenter)
# Shared 'visbrain' logger; the topoplot reports detected / ignored channels
# and connectivity drawing through it.
logger = logging.getLogger('visbrain')
class TopoObj(VisbrainObject):
    """Create a topoplot object.

    Parameters
    ----------
    name : string
        The name of the topoplot object.
    data : array_like
        Array of data of shape (n_channels). When it holds one value per
        channel, values of channels that cannot be displayed are dropped.
    xyz : array_like | None
        Array of channel coordinates of shape (n_channels, 2) or
        (n_channels, 3). For 2-D coordinates, z is set to 1.
    channels : list | None
        List of channel names. If xyz is None, the coordinates are looked up
        from these names (spherical coordinates, in degrees).
    system : {'cartesian', 'spherical'}
        Coordinate system of xyz.
    levels : array_like/int | None
        The levels at which the isocurve is constructed. If an int, that many
        levels are linearly spaced between the data minimum and maximum.
    level_colors : string/array_like | 'white'
        The color to use when drawing the line. If a list is given, it
        must be of shape (Nlev), if an array is given, it must be of
        shape (Nlev, ...) and provide one color per level
        (rgba, colorname). The name of a matplotlib colormap can also be
        used. By default, all levels are white.
    unit : {'degree', 'rad'}
        If system is 'spherical', specify if angles are in degrees or radians.
    line_color : array_like/string | 'black'
        Color of lines for the head, nose and ears.
    line_width : float | 3.
        Line width for the head, nose and ears.
    chan_size : float | 12.
        Size of channel names text.
    chan_offset : tuple | (0., 0., 0.)
        Offset added to the translation of the channel names.
    chan_mark_color : array_like/string | 'white'
        Color of channel markers.
    chan_mark_symbol : string | 'disc'
        Symbol to use for markers. Use disc, arrow, ring, clobber, square,
        diamond, vbar, hbar, cross, tailed_arrow, x, triangle_up,
        triangle_down, and star.
    chan_txt_color : array_like/string | 'black'
        Color of channel names.
    cmap : string | None
        Matplotlib colormap (like 'viridis', 'inferno'...).
    clim : tuple/list | None
        Colorbar limit. Every values under / over clim will
        clip.
    vmin : float | None
        Every values under vmin will have the color defined
        using the under parameter.
    vmax : float | None
        Every values over vmax will have the color defined
        using the over parameter.
    under : tuple/string | None
        Matplotlib color under vmin.
    over : tuple/string | None
        Matplotlib color over vmax.
    margin : float | .05
        Relative margin added around the topoplot in the camera rectangle.
    transform : VisPy.visuals.transforms | None
        VisPy transformation to set to the parent node.
    parent : VisPy.parent | None
        Parent of the topoplot.
    verbose : string
        Verbosity level.
    kw : dict | {}
        Optional arguments are used to control the colorbar
        (See :class:`ColorbarObj`).

    Notes
    -----
    List of supported shortcuts :

        * **s** : save the figure
        * **<delete>** : reset camera
    """

    ###########################################################################
    ###########################################################################
    #                                BUILT IN
    ###########################################################################
    ###########################################################################

    def __init__(self, name, data, xyz=None, channels=None, system='cartesian',
                 levels=None, level_colors='white', unit='degree',
                 line_color='black', line_width=3., chan_size=12.,
                 chan_offset=(0., 0., 0.), chan_mark_color='white',
                 chan_mark_symbol='disc', chan_txt_color='black',
                 cmap='viridis', clim=None, vmin=None, under='gray', vmax=None,
                 over='red', margin=.05, transform=None, parent=None,
                 verbose=None, **kw):
        """Init."""
        # Local import: interp2d (deprecated, removed in SciPy 1.14) is no
        # longer used by this class.
        from scipy.interpolate import RegularGridInterpolator

        VisbrainObject.__init__(self, name, parent, transform, verbose, **kw)
        # ======================== VARIABLES ========================
        scale = 800.  # fix GL bugs for small plots
        pos = np.zeros((1, 3), dtype=np.float32)
        # Colors :
        line_color = color2vb(line_color)
        chan_txt_color = color2vb(chan_txt_color)
        self._chan_mark_color = color2vb(chan_mark_color)
        self._chan_mark_symbol = chan_mark_symbol
        # Disc interpolation : the (pix, pix) grid computed from the channels
        # is resampled with a step of `_interp` samples, giving an image of
        # `csize` pixels per side.
        self._interp = .1
        self._pix = 64
        csize = int(self._pix / self._interp) if self._interp else self._pix
        half = csize / 2  # center and radius of the head circle

        # ======================== NODES ========================
        # Main topoplot node :
        self.node = scene.Node(name='Topoplot', parent=self._node)
        self.node.transform = vist.STTransform(scale=[scale] * 3)
        # Headset + channels :
        self.node_headfull = scene.Node(name='HeadChan', parent=self.node)
        # Headset node :
        self.node_head = scene.Node(name='Headset', parent=self.node_headfull)
        # Channel node :
        self.node_chan = scene.Node(name='Channels', parent=self.node_headfull)
        self.node_chan.transform = vist.STTransform(translate=(0., 0., -10.))
        # Keyword arguments shared by every head line :
        kw_line = {'width': line_width, 'color': line_color,
                   'parent': self.node_head, 'antialias': False}

        # ======================== PARENT VISUALS ========================
        # Main disc :
        self.disc = visuals.Image(pos=pos, name='Disc', parent=self.node_head,
                                  interpolation='bilinear')

        # ======================== HEAD / NOSE / EAR ========================
        # ------------------ HEAD ------------------
        self.head = visuals.Line(pos=pos, name='Head', **kw_line)
        # Circle of radius `half` centered on (half, half), drawn at z=-1 :
        theta = np.arange(0, 2 * np.pi, 0.001)
        head = np.full((len(theta), 3), -1., dtype=np.float32)
        head[:, 0] = half * (1. + np.cos(theta))
        head[:, 1] = half * (1. + np.sin(theta))
        self.head.set_data(pos=head)
        # ------------------ NOSE ------------------
        self.nose = visuals.Line(pos=pos, name='Nose', **kw_line)
        wn, hn = csize * 50. / 512., csize * 30. / 512.
        nose = np.array([[half - wn, 2 * half - wn, 2.],
                         [half, 2 * half + hn, 2.],
                         [half, 2 * half + hn, 2.],
                         [half + wn, 2 * half - wn, 2.]
                         ])
        self.nose.set_data(pos=nose, connect='segments')
        # ------------------ EAR ------------------
        we, he = csize * 10. / 512., csize * 30. / 512.
        ye = half + he * np.sin(theta)
        # Left ear (ellipse centered on the right edge of the head circle) :
        self.earL = visuals.Line(pos=pos, name='EarLeft', **kw_line)
        ear_l = np.full((len(theta), 3), 3., dtype=np.float32)
        ear_l[:, 0] = 2 * half + we * np.cos(theta)
        ear_l[:, 1] = ye
        self.earL.set_data(pos=ear_l)
        # Right ear (ellipse centered on the left edge of the head circle) :
        self.earR = visuals.Line(pos=pos, name='EarRight', **kw_line)
        ear_r = np.full((len(theta), 3), 3., dtype=np.float32)
        ear_r[:, 0] = 0. + we * np.cos(theta)
        ear_r[:, 1] = ye
        self.earR.set_data(pos=ear_r)

        # ================== CHANNELS ==================
        # Channel's markers :
        self.chan_markers = visuals.Markers(pos=pos, name='ChanMarkers',
                                            parent=self.node_chan)
        # Channel's text :
        self.chan_text = visuals.Text(pos=pos, name='ChanText',
                                      parent=self.node_chan, anchor_x='center',
                                      color=chan_txt_color,
                                      font_size=chan_size)

        # ================== CAMERA ==================
        self.rect = ((-scale / 2) * (1 + margin),
                     (-scale / 2) * (1 + margin),
                     scale * (1. + margin),
                     scale * (1.11 + margin))

        # ================== COORDINATES ==================
        auto = self._get_channel_coordinates(xyz, channels, system, unit)
        if auto:
            eucl = np.sqrt(self._xyz[:, 0]**2 + self._xyz[:, 1]**2).max()
            self.node_head.transform = vpnormalize(head, dist=2 * eucl)
            # Rescale between (-1:1, -1:1) = circle :
            circle = vist.STTransform(scale=(.5 / eucl, .5 / eucl, 1.))
            self.node_headfull.transform = circle
            # Text translation :
            tr = np.array([0., .8, 0.]) + np.array(chan_offset)
        else:
            # Get coordinates of references along the x and y-axis :
            ref_x, ref_y = self._get_ref_coordinates()
            # Recenter the topoplot :
            t = vist.ChainTransform()
            t.prepend(vprecenter(head))
            # Rescale (-ref_x:ref_x, -ref_y:ref_y) (ref_x != ref_y => ellipse)
            coef_x = 2 * ref_x / head[:, 0].max()
            coef_y = 2 * ref_y / head[:, 1].max()
            t.prepend(vist.STTransform(scale=(coef_x, coef_y, 1.)))
            self.node_head.transform = t
            # Rescale between (-1:1, -1:1) = circle :
            circle = vist.STTransform(scale=(.5 / ref_x, .5 / ref_y, 1.))
            self.node_headfull.transform = circle
            # Text translation :
            tr = np.array([0., .04, 0.]) + np.array(chan_offset)
        self.chan_text.transform = vist.STTransform(translate=tr)

        # ================== GRID INTERPOLATION ==================
        # Sample positions of the (pix, pix) grid and of the upsampled grid.
        # Query points beyond the last sample are clipped onto it, so the
        # edge value is repeated (nearest-neighbour extrapolation, as the
        # previous interp2d-based code did).
        x = y = np.arange(self._pix, dtype=float)
        xnew = ynew = np.clip(np.arange(0, self._pix, self._interp),
                              x[0], x[-1])
        # (len(ynew), len(xnew), 2) array of (y, x) query points :
        query = np.stack(np.meshgrid(ynew, xnew, indexing='ij'), axis=-1)

        def _grid_interpolation(grid):
            """Bilinearly resample a (pix, pix) grid whose rows follow y."""
            interp = RegularGridInterpolator((y, x), grid, method='linear')
            return interp(query)
        self._grid_interpolation = _grid_interpolation

        self.set_data(data, levels, level_colors, cmap, clim, vmin, under,
                      vmax, over)

    def __len__(self):
        """Return the number of channels."""
        return self._nchan

    def __bool__(self):
        """Return if coordinates exist."""
        return hasattr(self, '_xyz')

    def _get_camera(self):
        """Get the most adapted camera."""
        cam = scene.cameras.PanZoomCamera(rect=self.rect)
        cam.aspect = 1.
        return cam

    def set_data(self, data, levels=None, level_colors='white', cmap='viridis',
                 clim=None, vmin=None, under='gray', vmax=None, over='red'):
        """Set data to the topoplot.

        Parameters
        ----------
        data : array_like
            Array of data of shape (n_channels). When it holds one value per
            channel, values of channels that cannot be displayed are dropped;
            otherwise it must hold one value per displayed channel.
        levels : array_like/int | None
            The levels at which the isocurve is constructed. If an int, that
            many levels are linearly spaced between the data min and max.
        level_colors : string/array_like | 'white'
            The color to use when drawing the line. If a list is given, it
            must be of shape (Nlev), if an array is given, it must be of
            shape (Nlev, ...) and provide one color per level
            (rgba, colorname). The name of a matplotlib colormap can also be
            used. By default, all levels are white.
        cmap : string | None
            Matplotlib colormap (like 'viridis', 'inferno'...).
        clim : tuple/list | None
            Colorbar limit. Every values under / over clim will
            clip.
        vmin : float | None
            Every values under vmin will have the color defined
            using the under parameter.
        vmax : float | None
            Every values over vmax will have the color defined
            using the over parameter.
        under : tuple/string | None
            Matplotlib color under vmin.
        over : tuple/string | None
            Matplotlib color over vmax.

        Raises
        ------
        ValueError
            If the length of data matches neither the number of channels nor
            the number of displayed channels.
        """
        # ================== XYZ / CHANNELS / DATA ==================
        keep = self._keeponly
        xyz = self._xyz[keep]
        channels = list(np.array(self._channels)[keep])
        data = np.asarray(data, dtype=float).ravel()
        if len(data) == len(self):
            data = data[keep]
        if len(data) != len(xyz):
            raise ValueError("data must contain one value per channel (%i) or "
                             "per displayed channel (%i), got %i values"
                             % (len(self), len(xyz), len(data)))
        logger.info("    %i channels detected", len(channels))

        # =================== CHANNELS ===================
        # Markers (size scaled with the data between 10 and 30) :
        radius = normalize(data, 10., 30.)
        self.chan_markers.set_data(pos=xyz, size=radius, edge_color='black',
                                   face_color=self._chan_mark_color,
                                   symbol=self._chan_mark_symbol)
        # Names :
        self.chan_text.text = channels
        self.chan_text.pos = xyz

        # =================== GRID ===================
        pos_x, pos_y = xyz[:, 0], xyz[:, 1]
        xi = np.linspace(pos_x.min(), pos_x.max(), self._pix)
        yi = np.linspace(pos_y.min(), pos_y.max(), self._pix)
        xh, yh = np.meshgrid(xi, yi)
        grid = self._griddata(pos_x, pos_y, data, xh, yh)

        # =================== INTERPOLATION ===================
        if self._interp:
            grid = self._grid_interpolation(grid)
        csize = max(self._pix, grid.shape[0])
        # Pixels outside the inscribed disc :
        half = csize / 2
        y, x = np.ogrid[-half:half, -half:half]
        nmask = x**2 + y**2 >= half**2

        # =================== DISC ===================
        # Force min < off-disc values < max :
        d_min, d_max = data.min(), data.max()
        grid = normalize(grid, d_min, d_max)
        if clim is None:
            clim = (d_min, d_max)
        self._update_cbar_args(cmap, clim, vmin, vmax, under, over)
        grid_color = array2colormap(grid, **self.to_kwargs())
        # Off-disc pixels are made fully transparent :
        grid_color[nmask, -1] = 0.
        self.disc.set_data(grid_color)

        # =================== LEVELS ===================
        # Detach the isocurve of a previous call so successive calls neither
        # stack isocurves nor keep a stale one when levels is None.
        if getattr(self, 'iso', None) is not None:
            self.iso.parent = None
            self.iso = None
        if levels is not None:
            if isinstance(levels, (int, np.integer)):
                levels = np.linspace(d_min, d_max, levels)
            if isinstance(level_colors, str):
                # A colormap name colors each level from that colormap :
                cmaps = mpl_cmap('_r' in level_colors)
                if level_colors in cmaps:
                    level_colors = array2colormap(levels, cmap=level_colors)
            grid[nmask] = np.inf
            is_sc_image_installed(raise_error=True)
            self.iso = visuals.Isocurve(data=grid, parent=self.node_head,
                                        levels=levels, color_lev=level_colors,
                                        width=2.)
            self.iso.transform = vist.STTransform(translate=(0., 0., -5.))

    def connect(self, connect, **kwargs):
        """Draw connectivity lines between channels.

        Parameters
        ----------
        connect : array_like
            A 2D array of connectivity links of shape (n_channels, n_channels).
        kwargs : dict | {}
            Optional arguments are passed to the `visbrain.objects.ConnectObj`
            object.
        """
        logger.info("    Connect channels")
        # NOTE(review): this passes every channel position (self._xyz),
        # including channels ignored by the topoplot (see self._keeponly);
        # confirm the expected shape of `connect` in that case.
        self._connect = ConnectObj('ChanConnect', self._xyz, connect,
                                   parent=self.node_chan, **kwargs)

    def _get_channel_coordinates(self, xyz, channels, system, unit):
        """Get channel coordinates.

        The results are stored in ``self._xyz``, ``self._channels``,
        ``self._keeponly`` and ``self._nchan``.

        Parameters
        ----------
        xyz : array_like | None
            Array of source's coordinates of shape (nchan, 2) or (nchan, 3).
        channels : list | None
            List of channel names.
        system : {'cartesian', 'spherical'}
            Coordinate system.
        unit : string | {'degree', 'rad'}
            If system is 'spherical', specify if angles are in degrees or
            radians.

        Returns
        -------
        auto : bool
            True when the coordinates come from xyz, False when they were
            looked up from the channel names.

        Raises
        ------
        ValueError
            If both xyz and channels are None, if xyz has a wrong shape, if
            channels and xyz lengths differ, if channel names are not strings
            (when xyz is None), if system is unknown or if no channel can be
            displayed.
        """
        if (xyz is None) and (channels is None):
            raise ValueError("You must either define sources using the xyz or"
                             " channels inputs")
        if xyz is not None:
            # Float copy : accepts any array_like and guarantees that the
            # conversions below never modify the caller's array.
            xyz = np.array(xyz, dtype=float)
            if (xyz.ndim != 2) or (xyz.shape[1] not in (2, 3)):
                raise ValueError("Shape of xyz must be (nchan, 2) or "
                                 "(nchan, 3)")
            nchan = xyz.shape[0]
            if xyz.shape[1] == 2:
                # 2-D coordinates get a constant z = 1 :
                xyz = np.c_[xyz, np.ones((nchan,), dtype=float)]
            channels = [''] * nchan if channels is None else channels
            if len(channels) != nchan:
                raise ValueError("channels must contain one name per row of "
                                 "xyz (%i), got %i" % (nchan, len(channels)))
            keeponly = np.ones((nchan,), dtype=bool)
            auto = True
        else:
            if not all(isinstance(k, str) for k in channels):
                raise ValueError("channels must be a list of channel names "
                                 "(strings) when xyz is None")
            xyz, keeponly = self._get_coordinates_from_name(channels)
            # The reference file provides spherical coordinates in degrees :
            system, unit = 'spherical', 'degree'
            auto = False
        # Select channels to use :
        if not keeponly.all():
            ignore = list(np.array(channels)[np.invert(keeponly)])
            logger.warning("Ignored channels for topoplot : %s",
                           ', '.join(ignore))
        if not keeponly.any():
            raise ValueError("None of the channels can be displayed on the "
                             "topoplot")
        # ----------- Conversion -----------
        if system == 'spherical':
            xyz = self._spherical_to_cartesian(xyz, unit)
            xyz = self._array_project_radial_to3d(xyz)
        elif system != 'cartesian':
            raise ValueError("system must either be 'cartesian' or "
                             "'spherical', got %r" % (system,))
        self._xyz = xyz
        self._channels = channels
        self._keeponly = keeponly
        self._nchan = len(channels)
        return auto

    def _get_ref_coordinates(self, x='T4', y='Fpz'):
        """Get cartesian coordinates for electrodes to use as references.

        The ELAN software uses by default spherical coordinates with T4 as
        the extrema for the x-axis and Fpz as the extrema for the y-axis.

        Parameters
        ----------
        x : string | 'T4'
            Name of the electrode to use as a reference for the x-axis.
        y : string | 'Fpz'
            Name of the electrode to use as a reference for the y-axis.

        Returns
        -------
        ref_x, ref_y : float
            Projected x coordinate of `x` and y coordinate of `y`.

        Raises
        ------
        ValueError
            If one of the reference electrodes is missing from the reference
            file (its scale would otherwise silently be zero).
        """
        ref, found = self._get_coordinates_from_name([x, y])
        if not found.all():
            raise ValueError("Reference electrodes %s and %s must both be "
                             "defined in the reference file" % (x, y))
        ref = self._spherical_to_cartesian(ref, unit='degree')
        ref = self._array_project_radial_to3d(ref)
        return ref[0, 0], ref[1, 1]

    @staticmethod
    def _get_coordinates_from_name(chan):
        """From the name of the channels, find xyz coordinates.

        Parameters
        ----------
        chan : list
            List of channel names (matched case-insensitively, via lower(),
            against the names of the reference file).

        Returns
        -------
        xyz : array_like
            Float32 array of shape (n_chan, 3). Columns 0-1 hold the reference
            coordinates, column 2 is zero; rows of unknown channels are zero.
        keeponly : array_like
            Boolean array of shape (n_chan,), False for unknown channels.
        """
        # Load the coordinates template (closing the npz file afterwards) :
        path = download_file('eegref.npz', astype='topo')
        with np.load(path) as ref_file:
            name_ref, xyz_ref = ref_file['chan'], ref_file['xyz']
        keeponly = np.ones((len(chan),), dtype=bool)
        # Find and load xyz coordinates :
        xyz = np.zeros((len(chan), 3), dtype=np.float32)
        for num, k in enumerate(chan):
            idx = np.flatnonzero(name_ref == k.lower())
            if idx.size:
                xyz[num, 0:2] = xyz_ref[idx[0], :]
            else:
                keeponly[num] = False
        return xyz, keeponly

    @staticmethod
    def _spherical_to_cartesian(xyz, unit='rad'):
        """Convert spherical coordinates to cartesian.

        Parameters
        ----------
        xyz : array_like
            The array of spheric coordinate of shape (N, 3) (or (N, 2)):
            column 0 is theta, column 1 is phi. It is not modified.
        unit : {'rad', 'degree'}
            Specify the unit angles.

        Returns
        -------
        xyz : array_like
            New array of cartesian coordinates of shape (N, 3) :
            (sin(theta) cos(phi), sin(theta) sin(phi), cos(theta)).
        """
        xyz = np.asarray(xyz, dtype=float)
        theta, phi = xyz[:, 0], xyz[:, 1]
        if unit == 'degree':
            theta, phi = np.deg2rad(theta), np.deg2rad(phi)
        # Every component is computed from the untouched angles (the former
        # in-place version overwrote theta before computing z = cos(theta)).
        r = np.sin(theta)
        return np.c_[r * np.cos(phi), r * np.sin(phi), np.cos(theta)]

    @staticmethod
    def _griddata(x, y, v, xi, yi):
        """Interpolate scattered values onto a grid.

        Uses the Green's function g(d) = d**2 * (log(d) - 1) of biharmonic
        spline interpolation: weights are solved so that the spline passes
        through the known values v at (x, y), then evaluated at every node
        (xi, yi).

        Parameters
        ----------
        x, y, v : array_like
            Coordinates and values of the known points.
        xi, yi : array_like
            2-D arrays of grid coordinates (same shape).

        Returns
        -------
        zi : array_like
            Interpolated values, same shape as xi.
        """
        xy = x.ravel() + y.ravel() * -1j
        # Pairwise distances between known points :
        d = np.abs(xy[:, None] - xy[None, :])
        np.fill_diagonal(d, 1.)
        g = (d * d) * (np.log(d) - 1.)
        np.fill_diagonal(g, 0.)
        weights = np.linalg.solve(g, v.ravel())
        # Distances from every grid node to every known point (m, n, N) :
        dist = np.abs((xi + -1j * yi)[..., None] - xy)
        on_point = dist == 0
        dist[on_point] = 1.
        g = (dist * dist) * (np.log(dist) - 1.)
        g[on_point] = 0.
        return g.dot(weights)

    @staticmethod
    def _array_project_radial_to3d(points_2d):
        """Radial 3d projection.

        Each row p is mapped to (p_x * sin(a) / a, p_y * sin(a) / a, cos(a))
        with a = ||p|| (the scaling factor is 1 when a == 0).

        NOTE(review): when given unit vectors (the output of
        `_spherical_to_cartesian`), a == 1 for every row, so x / y are only
        scaled by sin(1).
        """
        points_2d = np.atleast_2d(points_2d)
        alphas = np.sqrt(np.sum(points_2d**2, -1))
        # sin(a) / a, without a divide-by-zero warning at a == 0 :
        betas = np.ones_like(alphas)
        nonzero = alphas != 0
        betas[nonzero] = np.sin(alphas[nonzero]) / alphas[nonzero]
        x = points_2d[..., 0] * betas
        y = points_2d[..., 1] * betas
        z = np.cos(alphas)
        return np.asarray([x, y, z]).T