Cope 2.5.0
My personal "standard library" of all the generally useful code I've written for various projects over the years
Loading...
Searching...
No Matches
plotly.py
"""
Helpers (functions and classes) that build on top of the plotly library.
"""
__version__ = '1.0.0'

# NOTE: ridgeplot() has been tested manually elsewhere, not in this file.
# TODO: bring those manual tests into this file.
def ridgeplot(df: 'DataFrame', x: str, y: str = None, z: 'str | list | tuple' = None,
              dist: float = .5, overlap: float = 0, label_x: float = -5,
              **kwargs) -> 'go.Figure':
    """ Create a ridgeplot in plotly.

    `df` may be a polars DataFrame or anything `polars.DataFrame()` accepts
        (e.g. a pandas DataFrame); it is converted to polars first.
    `x` is the name of the column in `df` that specifies the x data.
    `y` can either be the name of the column of the y axis data, if `df` is in
        long format, or left unspecified if `df` is in wide format. If in wide
        format, it's implied that the data in the columns specified by `z` are
        the y data.
    `z` is required. It is either a list/tuple of columns of `y` data (the
        column names are then the ridge labels) if `df` is in wide format, or
        the column name in `df` of the labels that say what each sample
        belongs to, if in long format. In long format the ridges are ordered
        by the first appearance of each label in that column.
    `dist` dictates how close each `z` is. Each ridge's y data is min-max
        scaled to [0, 1] and then lifted by `dist` above its row. 0 means the
        max value of each `z` is the min value of the one above it, and 1
        means the min value of each `z` is the start of the shading of the `z`
        above it. Increasing this also makes each ridge's shading taller.
    `overlap` affects the size of the shading directly, by lowering the
        baseline each ridge's shading starts from. Usually between -1 and 1.
    `label_x` is the x position (in data coordinates) of each ridge's text
        label. Defaults to -5, the previously hard-coded value.
    Any other keyword arguments are passed on to `fig.update_layout()`.

    A ridge whose y data is constant is drawn flat, rather than as NaN (which
    plotly would not draw at all).

    Returns the new `plotly.graph_objects.Figure`.
    Raises ValueError if `z` is None, and TypeError if `y` is None (wide
    format) but `z` is not a list or tuple.
    """
    import polars as pl
    import numpy as np
    import plotly.graph_objects as go

    # Validate explicitly: `assert` statements vanish under `python -O`
    if z is None:
        raise ValueError('ridgeplot() requires `z`')
    if y is None and not isinstance(z, (tuple, list)):
        raise TypeError('in wide format (`y` is None), `z` must be a list or tuple of column names')

    # Cast to polars, in case pandas was given
    if not isinstance(df, pl.DataFrame):
        df = pl.DataFrame(df)

    def _min_max(s):
        # Min-max scale `s` to [0, 1]. A constant series would give 0/0 = NaN,
        # so it is mapped to zeros instead; an empty/all-null series (whose
        # min() is None) is returned unchanged.
        lo, hi = s.min(), s.max()
        if lo is None or hi is None:
            return s
        if hi == lo:
            return s - lo
        return (s - lo) / (hi - lo)

    # (label, x data, normalized y data) for each ridge, in plotting order
    ridges = []
    if y is None:
        # Wide format: each column named in `z` is one ridge, all sharing `x`
        for col in z:
            ridges.append((col, df[x], _min_max(df[col])))
    else:
        # Long format: the rows holding each distinct value of column `z` form
        # one ridge. maintain_order keeps the ridge order deterministic.
        for label in df[z].unique(maintain_order=True):
            group = df.filter(pl.col(z) == label)
            ridges.append((label, group[x], _min_max(group[y])))

    n = len(ridges)
    # Every baseline spans the full x range of the data; compute it once
    x_min, x_max = df[x].min(), df[x].max()

    # The idea behind this ridgeline plot is to add the traces manually: for
    # each ridge, a white baseline at its row, then the ridge filled down to it
    fig = go.Figure()
    for index, (label, xs, ys) in enumerate(ridges):
        # Ridges are stacked top-down: the first one gets the highest row
        level = n - index

        # White horizontal line that this ridge's shading fills down to
        fig.add_trace(go.Scatter(
            x=[x_min, x_max],
            y=np.full(2, level - overlap),
            line_color='white',
        ))

        # The ridge itself; 'tonextx' fills to the trace added just before it,
        # i.e. the baseline above
        fig.add_trace(go.Scatter(
            x=xs,
            y=ys + level + dist,
            fill='tonextx',
            name=f'{label}',
        ))

        # Add the label for this ridge
        fig.add_annotation(
            x=label_x,
            y=level,
            text=f'{label}',
            showarrow=False,
            yshift=10,
        )

    fig.update_layout(
        showlegend=False,
        xaxis_title=x,
        yaxis_title=y,
        # These aren't useful, since they're not on the correct scale
        yaxis_showticklabels=False,
        **kwargs,
    )

    return fig
go.Figure ridgeplot('DataFrame' df, str x, str y=None, str z=None, float dist=.5, float overlap=0, **kwargs)
Create a ridgeplot in plotly.
Definition: plotly.py:8