[1]:
import os
import sys
import argparse
import pandas as pd
import torch
import anndata as ad
from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from DeepRUOT.losses import OT_loss1
from DeepRUOT.utils import (
    generate_steps, load_and_merge_config,
    SchrodingerBridgeConditionalFlowMatcher,
    generate_state_trajectory, get_batch, get_batch_size
)
from DeepRUOT.train import train_un1_reduce, train_all
from DeepRUOT.models import FNet, scoreNet2
from DeepRUOT.constants import DATA_DIR, RES_DIR
from DeepRUOT.exp import setup_exp

Convert adata to csv

[ ]:
import scanpy as sc
# Load your own AnnData object. It is assumed to be preprocessed already,
# with the reduced representation stored in `.obsm['X_pca']`.
original_data = sc.read_h5ad('../Weinreb_data.h5ad')

# Reduced coordinates that become the x1..xN columns of the CSV
X_reduced = original_data.obsm['X_pca']

# Convert time points; change the obs key and the rescaling below
# according to your own data
sample_values = original_data.obs['Time point'].copy()
sample_values = sample_values.apply(lambda x: (x - 2) / 2)

n_components = X_reduced.shape[1]
columns = [f'x{i+1}' for i in range(n_components)]

# One row per cell, one column per reduced component
df = pd.DataFrame(
    X_reduced,
    columns=columns
)

# Time labels go in the first column, named 'samples'
df.insert(0, 'samples', sample_values.values)

# Save as CSV file for DeepRUOT analysis; this file should be the same as
# the one referenced in the config file
output_file = '../data/Weinreb_data.csv'
# Create the target directory first so that to_csv does not fail when
# '../data' does not exist yet
os.makedirs(os.path.dirname(output_file), exist_ok=True)
df.to_csv(output_file, index=False)

[3]:
original_data
[3]:
AnnData object with n_obs × n_vars = 49302 × 2447
    obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone', 'batch'
    var: 'gene'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'

Load config

[4]:
# Location of the experiment configuration for this dataset
config_path = '../config/weinreb_config.yaml'

# Load and merge configuration
# NOTE(review): what the file is merged with is decided inside
# DeepRUOT.utils.load_and_merge_config — not visible from this notebook.
config = load_and_merge_config(config_path)

Load data and model

[5]:
data_cfg = config['data']

# Read the CSV and keep the leading time column plus the first `dim`
# coordinate columns
df = pd.read_csv(os.path.join(DATA_DIR, data_cfg['file_path'])).iloc[:, :data_cfg['dim'] + 1]

# Models are first placed on the CPU; later cells move them as needed
device = torch.device('cpu')

# Experiment directory and logger for this run
exp_dir, logger = setup_exp(RES_DIR, config, config['exp']['name'])

dim = data_cfg['dim']
[6]:
model_config = config['model']

# FNet takes its four constructor arguments straight from the model config;
# its `v_net` and `g_net` sub-networks are used later in this notebook.
f_net = FNet(
    **{key: model_config[key] for key in ('in_out_dim', 'hidden_dim', 'n_hiddens', 'activation')}
)
f_net = f_net.to(device)

# The score model uses its own hidden width ('score_hidden_dim')
sf2m_score_model = scoreNet2(
    in_out_dim=model_config['in_out_dim'],
    hidden_dim=model_config['score_hidden_dim'],
    activation=model_config['activation'],
)
sf2m_score_model = sf2m_score_model.float().to(device)
[7]:
# Checkpoints are always mapped onto the CPU when read from disk
cpu_device = torch.device('cpu')

f_net_weights = torch.load(os.path.join(exp_dir, 'model_final'), map_location=cpu_device)
f_net.load_state_dict(f_net_weights)
f_net.to(device)

score_weights = torch.load(os.path.join(exp_dir, 'score_model_final'), map_location=cpu_device)
sf2m_score_model.load_state_dict(score_weights)
sf2m_score_model.to(device)
[7]:
scoreNet2(
  (activation): LeakyReLU(negative_slope=0.01)
  (net): Sequential(
    (0): Linear(in_features=51, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=128, out_features=128, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Linear(in_features=128, out_features=1, bias=True)
  )
)

Dim reduction (Optional)

[8]:
import numpy as np
import joblib
from sklearn.decomposition import PCA  # Import PCA
import matplotlib.pyplot as plt
import seaborn as sns

import umap
# Alternative reducer (PCA) kept for reference:
# umap_op = PCA(n_components=2)
# 2D UMAP with a fixed seed; you may change UMAP to PCA or other dimension reduction methods
umap_op = umap.UMAP(n_components=2, random_state=42) # You may change UMAP to PCA or other dimension reduction methods
# Fit on every column of df except the first one (the 'samples' time column)
xu = umap_op.fit_transform(df.iloc[:, 1:])  # Assuming df is your DataFrame
# Persist the fitted reducer in the experiment directory; later cells reload it from 'dim_reduction.pkl'
joblib.dump(umap_op, os.path.join(exp_dir, 'dim_reduction.pkl'))  # Save the UMAP model
/lustre/home/2501111653/miniconda3/envs/DeepRUOTv2/lib/python3.10/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
[8]:
['/lustre/home/2501111653/DeepRUOTv2_test_data/results/weinreb_experiment/dim_reduction.pkl']

Plot Velocity

[ ]:
import numpy as np
import torch
import matplotlib.pyplot as plt

# Use the GPU when one is available, otherwise fall back to the CPU so this
# cell also runs on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
f_net.to(device)
sf2m_score_model.to(device)

# Time label and coordinates (columns x1..x{dim}) of every cell
all_times = df['samples'].values
all_data = df[[f'x{i}' for i in range(1, dim + 1)]].values

t_tensor = torch.tensor(all_times, dtype=torch.float32).unsqueeze(1).to(device)
data_tensor = torch.tensor(all_data, dtype=torch.float32).to(device)

# Calculate velocity of ODE (no gradients needed for this forward pass)
with torch.no_grad():
    velocity_ode = f_net.v_net(t_tensor, data_tensor)

velocity_ode_np = velocity_ode.cpu().numpy()

# The score is the gradient of the score model's output with respect to the
# cell coordinates, so gradient tracking is enabled on the inputs.
data_tensor.requires_grad_(True)

# Calculate score
log_density_values = sf2m_score_model(t_tensor, data_tensor)
log_density_values.backward(torch.ones_like(log_density_values))
score = data_tensor.grad
score = score.cpu().numpy()

# Calculate overall drift: the score term is only added when the configured
# noise level is non-zero
if config['score_train']['sigma'] == 0:
    drift = velocity_ode_np
else:
    drift = velocity_ode_np + score
/lustre/home/2501111653/miniconda3/envs/DeepRUOTv2/lib/python3.10/site-packages/torch/autograd/graph.py:829: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at /pytorch/aten/src/ATen/cuda/CublasHandlePool.cpp:179.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[10]:
import anndata
import scvelo as scv
import scanpy as sc
import numpy as np

# Inputs expected from the previous cell:
#   all_data  : cell coordinates, shape (n_cells, dim)
#   drift     : drift vector evaluated at every cell (same layout as all_data)
#   all_times : time label of every cell

# Load the fitted dimension reducer saved earlier in this notebook
dim_reducer = joblib.load(os.path.join(exp_dir, 'dim_reduction.pkl')) # If no dim reducer is needed, set this to None

# Create AnnData object
adata = anndata.AnnData(X=all_data)

# Set 'Ms' layer to avoid KeyError: 'Ms'
adata.layers['Ms'] = all_data  # Use the model-space coordinates as state matrix

# Set velocity vectors
adata.layers['velocity'] = drift  # Store velocity vectors in layers

# Set the 2D embedding used for plotting
if dim_reducer is not None:
    X_umap = dim_reducer.transform(all_data)
else:
    # No reducer: use the first two coordinates of every cell. The previous
    # `all_data[:2]` selected the first two *rows* instead.
    X_umap = all_data[:, :2]
adata.obsm['X_umap'] = X_umap
adata.obs['time'] = all_times

if adata.layers['velocity'].shape[1] != 2:
    # Compute neighbor graph (required for velocity graph)
    sc.pp.neighbors(adata, n_neighbors=30, use_rep='X')  # Calculate neighbors based on high-dim data

    # Compute velocity graph
    scv.tl.velocity_graph(adata, vkey='velocity', n_jobs=16)  # Build velocity graph from high-dim velocity vectors

    # Project velocities to UMAP space
    scv.tl.velocity_embedding(adata, basis='umap', vkey='velocity')  # Project velocities to UMAP
else:
    # Velocities are already two-dimensional: use them directly
    adata.obsm['velocity_umap'] = adata.layers['velocity']

# Categorical copy of the time labels for plotting
adata.obs['time_categorical'] = pd.Categorical(adata.obs['time'])

# Visualization settings
scv.settings.set_figure_params('scvelo')  # Set scvelo plotting style
computing velocity graph (using 16/64 cores)
    finished (0:00:22) --> added
    'velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:07) --> added
    'velocity_umap', embedded velocity vectors (adata.obsm)
[11]:
# Optional: load the cell type information

# Need original data to get celltype
original_data = sc.read_h5ad('../Weinreb_data.h5ad')
# Labels are copied positionally (`.values`), not aligned by index.
# NOTE(review): this assumes adata rows are in the same order as the rows of
# the original AnnData — confirm for your own data.
adata.obs['cell_type'] = original_data.obs['Cell type annotation'].values

# Visualization settings
scv.settings.set_figure_params('scvelo')  # Set scvelo plotting style
[12]:
# Streamline plot of the embedded velocities on the UMAP basis, coloured by
# cell type and written as an SVG into the experiment directory.
scv.pl.velocity_embedding_stream(
    adata,
    basis='umap',
    color='cell_type',
    figsize=(7, 5),
    density=3,
    title='Velocity Streamline',
    # legend_loc='right',
    palette='plasma',
    save=exp_dir+'/all_velocity_stream_plot.svg'
)
saving figure to file /lustre/home/2501111653/DeepRUOTv2_test_data/results/weinreb_experiment/all_velocity_stream_plot.svg
../_images/notebook_analysis_16_1.png

Fit Potential

[13]:
# Fit potential on 2D UMAP to create landscape
# NOTE(review): `input` shadows the Python builtin of the same name for the
# rest of the kernel session; consider renaming when refactoring.
input = adata.obsm['X_umap']
output = adata.obsm['velocity_umap']
import torch
import torch.nn as nn
import torch.optim as optim

# Embedding coordinates (network input) as a float32 tensor
X = torch.tensor(input, dtype=torch.float32)
# Embedded velocities (regression target); `.values` is used when the stored
# object exposes it, otherwise the object is converted directly
V = torch.tensor(output.values if hasattr(output, 'values') else output, dtype=torch.float32)

class PotentialNet(nn.Module):
    """Small MLP mapping a point to a single scalar potential value.

    Args:
        in_dim: Number of input features per point.
        hidden_dim: Width of the two hidden layers.
    """

    def __init__(self, in_dim=2, hidden_dim=128):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the potential of each input row with the trailing size-1 axis removed."""
        potential = self.net(x)
        return potential.squeeze(-1)

# Train on the GPU when available; fall back to the CPU instead of failing
# on machines without CUDA.
fit_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

potential_net = PotentialNet(in_dim=X.shape[1]).to(fit_device)
# The gradient of the potential with respect to the coordinates is needed in
# every epoch, so gradient tracking is enabled once, before the loop.
X = X.to(fit_device).requires_grad_(True)
V = V.to(fit_device)

optimizer = optim.Adam(potential_net.parameters(), lr=1e-3)
n_epochs = 2000

for epoch in range(n_epochs):
    optimizer.zero_grad()
    phi = potential_net(X)
    # create_graph=True keeps the gradient differentiable so that the loss on
    # grad_phi can be back-propagated into the network parameters
    grad_phi = torch.autograd.grad(
        phi, X,
        grad_outputs=torch.ones_like(phi),
        create_graph=True,
    )[0]  # (N, 2)
    # Velocity is the negative gradient of potential
    pred_V = -grad_phi
    loss = ((pred_V - V)**2).mean()
    loss.backward()
    optimizer.step()
    if epoch % 50 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.6f}")
Epoch 0, Loss: 0.004214
Epoch 50, Loss: 0.000147
Epoch 100, Loss: 0.000109
Epoch 150, Loss: 0.000092
Epoch 200, Loss: 0.000085
Epoch 250, Loss: 0.000082
Epoch 300, Loss: 0.000079
Epoch 350, Loss: 0.000075
Epoch 400, Loss: 0.000073
Epoch 450, Loss: 0.000072
Epoch 500, Loss: 0.000077
Epoch 550, Loss: 0.000078
Epoch 600, Loss: 0.000075
Epoch 650, Loss: 0.000070
Epoch 700, Loss: 0.000071
Epoch 750, Loss: 0.000072
Epoch 800, Loss: 0.000070
Epoch 850, Loss: 0.000072
Epoch 900, Loss: 0.000070
Epoch 950, Loss: 0.000068
Epoch 1000, Loss: 0.000070
Epoch 1050, Loss: 0.000068
Epoch 1100, Loss: 0.000069
Epoch 1150, Loss: 0.000073
Epoch 1200, Loss: 0.000073
Epoch 1250, Loss: 0.000070
Epoch 1300, Loss: 0.000073
Epoch 1350, Loss: 0.000073
Epoch 1400, Loss: 0.000076
Epoch 1450, Loss: 0.000074
Epoch 1500, Loss: 0.000076
Epoch 1550, Loss: 0.000078
Epoch 1600, Loss: 0.000077
Epoch 1650, Loss: 0.000082
Epoch 1700, Loss: 0.000081
Epoch 1750, Loss: 0.000081
Epoch 1800, Loss: 0.000083
Epoch 1850, Loss: 0.000083
Epoch 1900, Loss: 0.000085
Epoch 1950, Loss: 0.000083
[14]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import anndata
import torch.nn as nn

# Evaluate on whichever device the fitted potential network lives on, rather
# than assuming CUDA.
device = next(potential_net.parameters()).device

# Calculating Potential for original cells
umap_coords = adata.obsm['X_umap']
with torch.no_grad():
    input_tensor = torch.from_numpy(umap_coords).float().to(device)
    original_potentials = potential_net(input_tensor).cpu().numpy().flatten()

# Plot potential
sns.set_style("white")
fig, ax = plt.subplots(figsize=(12, 8))

scatter = sns.scatterplot(
    x=umap_coords[:, 0],
    y=umap_coords[:, 1],
    hue=original_potentials,
    palette='RdYlBu_r',
    s=5,
    alpha=0.8,
    edgecolor='none',
    ax=ax,
    legend=False
)

# Colour bar spanning the observed potential range, using the same colormap
# as the scatter plot
norm = plt.Normalize(original_potentials.min(), original_potentials.max())
sm = plt.cm.ScalarMappable(cmap="RdYlBu_r", norm=norm)
sm.set_array([])

cbar = fig.colorbar(sm, ax=ax)
cbar.set_label('Potential', fontsize=12)

# Bare embedding: no ticks, labels or frame
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('')
ax.set_ylabel('')
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_visible(False)

ax.set_aspect('equal', adjustable='box')
plt.grid(False)
plt.tight_layout()

# Save only after the final layout adjustments so the file on disk matches
# the figure that is displayed
plt.savefig(exp_dir+'/Potential_landscape.png', dpi=300, bbox_inches='tight')
plt.show()

../_images/notebook_analysis_19_0.png

Calculate Fate Probability

[15]:
import cellrank as cr
# Transition matrix derived from the velocities stored in adata
vk = cr.kernels.VelocityKernel(adata)
vk.compute_transition_matrix()

g = cr.estimators.GPCCA(vk)

# Manually define terminal states: every cell annotated with one of these
# cell types is assigned to the terminal state of the same name
terminal_cell_types = ["Neutrophil", "Monocyte", "Meg", "Mast", "Baso"]
manual_terminal_states = {
    cell_type_name: adata.obs_names[adata.obs["cell_type"] == cell_type_name].tolist()
    for cell_type_name in terminal_cell_types
}

g.set_terminal_states(manual_terminal_states)

# We can plot to confirm that the terminal states are correctly set
g.plot_macrostates(which="terminal", mode='embedding', legend_loc="right margin")
/lustre/home/2501111653/miniconda3/envs/DeepRUOTv2/lib/python3.10/site-packages/scvelo/plotting/scatter.py:656: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  smp = ax.scatter(
/lustre/home/2501111653/miniconda3/envs/DeepRUOTv2/lib/python3.10/site-packages/scvelo/plotting/scatter.py:694: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(
/lustre/home/2501111653/miniconda3/envs/DeepRUOTv2/lib/python3.10/site-packages/scvelo/plotting/utils.py:1396: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, s=bg_size, marker=".", c=bg_color, zorder=zord - 2, **kwargs)
/lustre/home/2501111653/miniconda3/envs/DeepRUOTv2/lib/python3.10/site-packages/scvelo/plotting/utils.py:1397: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  ax.scatter(x, y, s=gp_size, marker=".", c=gp_color, zorder=zord - 1, **kwargs)
../_images/notebook_analysis_21_3.png
[16]:
# Calculate fate probabilities
g.compute_fate_probabilities()
# Plot them on the embedding and save the figure as an SVG in the experiment directory
g.plot_fate_probabilities(mode="embedding", title = '', save=exp_dir+'/all_fate_probabilities_plot.svg')
WARNING: Unable to import petsc4py. For installation, please refer to: https://petsc4py.readthedocs.io/en/stable/install.html.
Defaulting to `'gmres'` solver.
/lustre/home/2501111653/miniconda3/envs/DeepRUOTv2/lib/python3.10/site-packages/scvelo/plotting/scatter.py:656: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  smp = ax.scatter(
/lustre/home/2501111653/miniconda3/envs/DeepRUOTv2/lib/python3.10/site-packages/scvelo/plotting/scatter.py:656: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  smp = ax.scatter(
saving figure to file /lustre/home/2501111653/DeepRUOTv2_test_data/results/weinreb_experiment/all_fate_probabilities_plot.svg
../_images/notebook_analysis_22_4.png

Select transition genes

[17]:
# Project the velocity back to the gene expression space using the 'PCs'
# matrix stored in varm
pca_components = original_data.varm['PCs'].T
v_ori = drift @ pca_components

# Rank genes by their mean projected velocity over all cells and keep the
# 100 largest, in descending order
mean_velocity_order = np.argsort(v_ori.mean(axis=0))
top_100_idx = mean_velocity_order[::-1][:100]

# Get the gene names
gene_names = original_data.var['gene'].values[top_100_idx]
print("Transition genes:", gene_names)

Transition genes: ['Psap' 'Ctsb' 'Fth1' 'Ctss' 'Gpnmb' 'Lgals3' 'B2m' 'Lyz2' 'Fabp5' 'Grn'
 'Vim' 'Clec4n' 'Ctsd' 'Anxa4' 'Cd9' 'Itgb2' 'Sirpa' 'Mrc1' 'Mpeg1' 'Lpl'
 'Ahnak' 'Anpep' 'Wfdc17' 'Clec7a' 'Lgmn' 'Fcer1g' 'Gsn' 'Gpr137b' 'Cstb'
 'Laptm5' 'Timp2' 'Mmp12' 'Cd74' 'Atp6v0d2' 'Itm2b' 'Gns' 'H2-Aa' 'Npc2'
 'Lrpap1' 'C3ar1' 'Atp6v0c' 'Sgpl1' 'Myof' 'Plxna1' 'Cd300c2' 'Cd68'
 'Slc6a6' 'Sat1' 'Emp1' 'Rnh1' 'Fabp4' 'Dab2' 'S100a4' 'Cyba' 'Tyrobp'
 'Lilrb4a' 'Lamp1' 'Tnfaip2' 'Ftl1' 'Akr1a1' 'Bnip3l' 'H2-Eb1' 'Cd44'
 'Bri3' 'Anxa5' 'Iqgap1' 'Btg1' 'Lipa' 'Ctla2a' 'Itgam' 'Lasp1' 'Ptms'
 'Klf6' 'Cybb' 'Mcl1' 'Abcg1' 'Lgals3bp' 'Il7r' 'Tapbp' 'Mmp8' 'Gabarap'
 'Plek' 'S100a6' 'Adgre1' 'Cst3' 'Fam198b' 'Clec4a1' 'Clec4d' 'Mpp1'
 'Cndp2' 'Ly6a' 'Plin2' 'Efhd2' 'Actb' 'Rab7b' 'Ctsa' 'Ms4a6d' 'Csf1r'
 'Neat1' 'Cd52']

Growth Rate Analysis

If the growth term is disabled (use_mass set to False), omit this analysis.

[ ]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
import joblib

# Load dimension reducer
dim_reducer = joblib.load(os.path.join(exp_dir, 'dim_reduction.pkl')) # If no dim reducer is needed, set this to None
# The growth analysis runs on the CPU; this `device` value is also the
# default of plot_g_values' `device` parameter, bound when the function is defined
device = 'cpu'
f_net = f_net.to(device)

def plot_g_values(df, f_net, dim_reducer=None, device=device):
    """Plot all cells in 2D, coloured by the growth rate predicted by ``f_net``.

    Args:
        df: DataFrame with a 'samples' time column and coordinate columns
            'x1' .. 'x{dim}'.
        f_net: Model called as ``f_net(t, data)``; the second of its four
            return values is used as the growth rate.
        dim_reducer: Optional fitted object with a ``transform`` method that
            maps the coordinates to 2D. If None, the first two coordinate
            columns are plotted.
        device: Device on which the input tensors are created.

    Notes:
        Reads the notebook-level globals ``dim`` and ``exp_dir``. The figure
        is written to 'g_values_plot.svg' in ``exp_dir`` and then shown.
    """
    # Coordinate columns are the same for every time point
    column_names = [f'x{i}' for i in range(1, dim + 1)]

    # Evaluate the growth rate for the cells of each time point
    data_by_time = {}
    for time in df['samples'].unique():
        subset = df[df['samples'] == time]
        data = torch.tensor(subset[column_names].values, dtype=torch.float32).to(device)
        with torch.no_grad():
            t = torch.tensor([time], dtype=torch.float32).to(device)
            _, g, _, _ = f_net(t, data)

        data_by_time[time] = {'data': subset, 'g_values': g.detach().cpu().numpy()}

    # Combine all g_values
    all_g_values = np.concatenate([content['g_values'] for content in data_by_time.values()])

    # Clip the colour scale at the 95th percentile of the growth values
    vmax_value = np.percentile(all_g_values, 95)
    norm = plt.Normalize(vmin=0, vmax=vmax_value, clip=True)

    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot data for each time point on the same axis
    for content in data_by_time.values():
        new_data = content['data'][column_names]

        if dim_reducer is not None:
            data_reduced = dim_reducer.transform(new_data)
        else:
            data_reduced = new_data.iloc[:, :2].values

        # Map g_values to colors
        colors = plt.cm.RdYlBu_r(norm(content['g_values']))
        ax.scatter(data_reduced[:, 0], data_reduced[:, 1], c=colors, alpha=0.8, marker='o', s=10)

    # Bare embedding: no ticks, axis labels or frame
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('')
    ax.set_ylabel('')
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

    # Add colorbar that shares the clipped normalisation
    sm = plt.cm.ScalarMappable(cmap='RdYlBu_r', norm=norm)
    sm.set_array(all_g_values)
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label('Predicted Growth Rate')

    # Format colorbar ticks with two decimals
    cbar.ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{(x):.2f}'))

    # Save as SVG
    plt.savefig(os.path.join(exp_dir, 'g_values_plot.svg'), bbox_inches='tight', transparent=True)

    plt.show()

# Plot with f_net and df, projecting cells with the reloaded dimension reducer
plot_g_values(df, f_net, dim_reducer=dim_reducer)
../_images/notebook_analysis_26_0.png
[19]:
# Select growth related genes
# Step 1: gradient of the predicted growth rate with respect to the input
# coordinates, computed separately for the cells of each time point.
column_names = [f'x{i}' for i in range(1, dim + 1)]

g_gradients = []
time_points = df['samples'].unique()
for time in time_points:
    subset = df[df['samples'] == time]

    data = torch.tensor(subset[column_names].values, dtype=torch.float32).to(device)
    # Enable gradient tracking explicitly so that `data.grad` is populated
    # regardless of what f_net does to its input internally.
    data.requires_grad_(True)

    t = torch.tensor([time], dtype=torch.float32).to(device)
    _, g, _, _ = f_net(t, data)

    # Back-propagate a vector of ones, i.e. sum the growth values over cells
    g.backward(gradient=torch.ones_like(g))

    # calculate gradients of growth
    g_grad = data.grad.detach().cpu().numpy()
    g_gradients.append(g_grad)

g_gradients_np = np.concatenate(g_gradients, axis=0)
[20]:
# Project the growth gradients back to gene space using the 'PCs' matrix
# stored in varm
pca_components = original_data.varm['PCs'].T
g_ori = g_gradients_np @ pca_components

# Rank genes by their mean projected gradient over all cells and keep the
# 100 largest, in descending order
mean_gradient_order = np.argsort(g_ori.mean(axis=0))
top_100_idx = mean_gradient_order[::-1][:100]

# Get the gene names
gene_names = original_data.var['gene'].values[top_100_idx]
print("Growth related genes:", gene_names)

Growth related genes: ['Rps27rt' 'Prss34' 'Rab33b' 'Als2' 'Cdh1' 'Golga3' 'Gm5483' 'BC100530'
 'Mcpt8' 'Crip1' 'Calr' 'Siglech' 'Lgals1' 'Akr1c18' 'Atp1b1' 'Tmed3'
 'Gm37214' 'Fry' 'Stfa3' 'Stfa2' 'Rpl15' 'Igfbp7' 'Dok2' 'Mboat1' 'Ms4a2'
 'Alox5' '2810403D21Rik' 'Ehd3' 'Rap1b' 'Pth1r' 'Hsp90b1' 'Hspa5'
 'Gm26721' 'Gm15402' '4930589L23Rik' 'Gnmt' 'Tgfbi' 'Slc14a1' 'Ubac2'
 'Zfp184' 'Hao1' 'Igsf8' '2010005H15Rik' 'Slc35d3' 'Muc20' 'Havcr1'
 'Slc4a1' 'Alox15' 'Cnrip1' 'Gpr4' 'Serpinb1a' 'P2ry14' 'Gm13709' 'Atoh8'
 'Ly86' 'Prdx1' 'Rab25' 'Perp' 'Hmgn3' 'Ptger3' 'Slc7a8' 'Pltp' 'Arc'
 'Gm11335' 'Mafb' 'Inpp4b' 'Vcl' 'Serpine2' 'Fam178b' 'Cyp4f18' 'Cyp11a1'
 'Hba-a2' 'Fcrla' 'Pdlim4' 'P2ry1' 'Tbxa2r' 'Epx' 'Clec4a3' 'Blnk' 'Prg2'
 'Anxa6' 'Padi2' 'Slc6a9' 'Timp3' 'Cldn11' 'Rsad1' 'Cebpe' 'Gata2' 'Itgb7'
 'D13Ertd608e' 'Rbpms2' 'Gm5416' 'Alas2' 'Klf5' 'Sucnr1' 'Prg3' 'P2ry10'
 'Spint1' 'Dhrs9' 'Akr1c13']

Interpolation

[21]:
from DeepRUOT.utils import euler_sdeint
import random
import joblib
import numpy as np
# NOTE(review): CUDA is hard-coded here; this cell fails on CPU-only machines.
device = 'cuda'
f_net.to(device)
sf2m_score_model.to(device)
all_times = df['samples'].values
# NOTE(review): n_times is not used anywhere in the visible part of the notebook.
n_times = all_times.max() + 1
# All rows whose time label is 0, including the leading 'samples' column
data=torch.tensor(df[df['samples']==0].values,dtype=torch.float32).requires_grad_()
# Drop the 'samples' column to keep only the coordinates, then move to the device
data_t0 = data[:, 1:].to(device).requires_grad_()
print(data_t0.shape)
# Initial positions for the simulation (already on `device`)
x0=data_t0.to(device)

# Reload the fitted dimension reducer saved earlier in this notebook
dim_reducer = joblib.load(os.path.join(exp_dir, 'dim_reduction.pkl'))

class SDE(torch.nn.Module):
    """Drift and diffusion terms handed to the SDE integrator.

    The state passed to ``f`` is a pair ``(z, lnw)``; ``f`` returns one
    derivative for each element of the pair.

    Args:
        ode_drift: Callable ``(t, z)`` giving the deterministic velocity.
        g: Callable ``(t, z)`` whose output is returned as the second
            element of ``f``.
        score: Object exposing ``compute_gradient(t, z)``; its result is
            added to the velocity.
        input_size: Stored on the instance; not used by ``f`` or ``g``.
        sigma: Constant diffusion coefficient.
    """

    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, ode_drift, g, score, input_size=(3, 32, 32), sigma=1.0):
        super().__init__()
        self.drift = ode_drift
        self.score = score
        self.input_size = input_size
        self.sigma = sigma
        self.g_net = g

    # Drift
    def f(self, t, y):
        """Return ``(dz, dlnw)`` for the state ``y = (z, lnw)`` at time ``t``."""
        state, _log_weight = y
        velocity = self.drift(t, state)
        growth = self.g_net(t, state)
        # The score model receives one time entry per row of the state;
        # expand keeps gradient information of t
        t_per_row = t.expand(state.shape[0], 1)
        score_term = self.score.compute_gradient(t_per_row, state)
        return (velocity + score_term, growth)

    # Diffusion
    def g(self, t, y):
        """Return a tensor shaped like ``y`` filled with ``sigma``."""
        return torch.ones_like(y) * self.sigma

# Initial state: the t=0 cells together with uniform log-weights log(1/N)
x0_subset = x0.to(device)
n_initial = x0_subset.shape[0]
lnw0 = torch.log(torch.ones(n_initial, 1) / n_initial).to(device)
initial_state = (x0_subset, lnw0)

# Define SDE object from the velocity, growth and score networks
sde = SDE(
    f_net.v_net,
    f_net.g_net,
    sf2m_score_model,
    input_size=(2,),
    sigma=config['score_train']['sigma'],
)

ts_points = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float32)  # 0.5 is the unseen timepoint
print(ts_points)

# Integrate and collect the state at each requested time
sde_point, traj_lnw = euler_sdeint(sde, initial_state, dt=0.1, ts=ts_points)
print(sde_point.shape)
print(traj_lnw.shape)

# Convert log-weights to weights and normalise them over dim 1
weight = torch.exp(traj_lnw)
weight_normed = weight / weight.sum(dim=1, keepdim=True)

# NumPy / object-array copies of the simulated points
sde_point_np = sde_point.detach().cpu().numpy()
sde_point_list = sde_point_np.tolist()
sde_point_array = np.array(sde_point_list, dtype=object)
torch.Size([4638, 50])
tensor([0.0000, 0.5000, 1.0000])
torch.Size([3, 4638, 50])
torch.Size([3, 4638, 1])
[23]:
# Cell type annotation of every cell, as a categorical
cell_type = original_data.obs['Cell type annotation'].values
all_labels = pd.Categorical(cell_type)

# Full data table (all columns) with the annotation appended as the last column
df_new = pd.read_csv(
    os.path.join(DATA_DIR, config['data']['file_path'])
).assign(Annotation=all_labels)
df_new
[23]:
samples x1 x2 x3 x4 x5 x6 x7 x8 x9 ... x42 x43 x44 x45 x46 x47 x48 x49 x50 Annotation
0 0 -1.217456 -1.876922 -1.205544 -2.138494 -2.375819 -1.729328 0.651229 -0.338510 0.041160 ... 0.054111 0.278638 0.218402 1.900351 0.188479 -0.624556 -0.143188 0.780808 -0.624784 Undifferentiated
1 0 -5.243580 -1.761129 -1.729367 -1.019093 0.175251 0.097280 -0.423565 -0.407035 -2.576229 ... -0.024745 0.428439 0.536177 0.373199 -0.229468 0.097612 -0.208186 0.435171 0.263994 Undifferentiated
2 0 -5.752447 -1.419319 -2.102163 -0.605931 0.029611 0.116759 -0.268277 -0.513728 -0.409831 ... 0.045515 0.171796 0.163721 0.497674 -0.512460 0.028188 -0.160119 0.131542 0.413676 Undifferentiated
3 0 -4.255497 -2.384707 -0.330012 -1.585662 -2.652925 0.520510 0.198466 -1.104815 -0.800885 ... 0.311630 -0.553470 -0.514273 0.416959 0.150009 0.026914 0.030274 0.200624 0.081662 Undifferentiated
4 0 -4.877692 0.824647 0.232769 2.845089 -1.172621 -0.222511 0.250738 -0.406565 -0.113559 ... 0.059187 0.898496 -0.194554 0.430735 0.015443 -0.698956 1.154398 -0.912015 -0.101607 Undifferentiated
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
49297 2 -7.065715 2.007633 0.619574 3.793502 -1.925619 2.546932 0.242512 0.905567 1.139781 ... -0.090863 0.651469 0.029022 0.714116 -1.062683 -0.118945 0.358995 0.626652 -0.648994 Undifferentiated
49298 2 -5.770710 -1.731957 -0.978848 -1.203649 -2.223124 0.994655 -0.472510 -0.973434 0.074024 ... 0.207697 -0.309746 -0.290887 -0.561833 -0.185721 -0.286728 0.138851 0.289373 -0.958307 Undifferentiated
49299 2 0.250359 -2.772287 1.352441 -2.562624 -2.185664 -2.237595 1.094841 -0.884417 0.565153 ... 0.173202 -0.346234 0.683077 0.718787 -0.206094 0.077363 -0.306639 -0.567678 0.629165 Neutrophil
49300 2 12.654426 0.758303 -6.569150 -1.810055 -3.827207 -6.287392 2.184073 4.394424 -1.419851 ... -1.361348 0.553498 0.251884 -0.941483 -0.814592 -0.282054 0.654892 -0.863675 -0.089109 Monocyte
49301 2 -5.765212 -1.737552 -1.499219 -0.452036 -0.876632 0.868453 -0.872944 -0.954929 0.379570 ... -0.860799 0.515217 0.272714 -0.171447 -0.189407 0.038520 -0.406322 -0.203042 -0.801902 Undifferentiated

49302 rows × 52 columns

[24]:
# Cell annotation
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np

# Features: every column except the last one ('Annotation'), i.e. the
# 'samples' time column followed by the coordinates
X = df_new.iloc[:,:-1].values
# Target: the cell type annotation
y = df_new['Annotation'].values

# Encode the string labels as integers; the encoder is reused later to map
# predictions back to label names
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)
# Hold out 20% of the cells for evaluation, with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)

# Convert to tensors: float32 features, int64 class indices
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)
# Classifier
class MLP(nn.Module):
    """Three-layer perceptron classifier that returns raw class logits.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the two hidden layers.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.LeakyReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Apply fc1 and fc2, each followed by the LeakyReLU, then fc3."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Derive the input width from the training features (the 'samples' column
# plus the coordinate columns) instead of hard-coding 51, so the classifier
# also matches datasets with a different number of dimensions.
input_size = X_train.shape[1]
hidden_size = 128
num_classes = len(label_encoder.classes_)
model = MLP(input_size, hidden_size, num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
[ ]:
num_epochs = 10000
# Regularisation strengths; constant over the whole run
reg_lambda = 1e-2     # weight of the input-gradient penalty
weight_decay = 1e-4   # weight of the explicit L2 penalty on the parameters

model = model.cuda()
X_train = X_train.cuda()
# Copy the training labels and test features to the GPU once instead of in
# every epoch / evaluation step
y_train_gpu = y_train.cuda()
X_test_gpu = X_test.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model.train()
    # Gradients with respect to the inputs are needed for the penalty below
    X_train.requires_grad_(True)
    outputs = model(X_train)
    loss = criterion(outputs, y_train_gpu)

    # Gradient of the summed logits with respect to the inputs; kept in the
    # graph (create_graph=True) so the penalty itself can be back-propagated
    grads = torch.autograd.grad(
        outputs=outputs,
        inputs=X_train,
        grad_outputs=torch.ones_like(outputs),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    # Penalise sensitivity to the first input feature (the 'samples' column)
    grad_norm = grads[:, 0].abs().mean()

    # Explicit squared-L2 penalty over all trainable parameters
    l2_reg = torch.tensor(0., device=loss.device)
    for param in model.parameters():
        if param.requires_grad:
            l2_reg = l2_reg + torch.norm(param, 2) ** 2

    loss = loss + reg_lambda * grad_norm + weight_decay * l2_reg
    # Stop tracking the inputs before the backward pass
    X_train.requires_grad_(False)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
        model.eval()
        with torch.no_grad():
            outputs = model(X_test_gpu)
            _, predicted = torch.max(outputs, 1)
            accuracy = accuracy_score(y_test, predicted.cpu())
            print(f'Acc: {accuracy:.4f}')
Epoch [100/10000], Loss: 0.1152
Acc: 0.9595
Epoch [200/10000], Loss: 0.0873
Acc: 0.9662
Epoch [300/10000], Loss: 0.0746
Acc: 0.9686
Epoch [400/10000], Loss: 0.0660
Acc: 0.9692
Epoch [500/10000], Loss: 0.0592
Acc: 0.9694
Epoch [600/10000], Loss: 0.0534
Acc: 0.9694
Epoch [700/10000], Loss: 0.0484
Acc: 0.9693
Epoch [800/10000], Loss: 0.0440
Acc: 0.9691
Epoch [900/10000], Loss: 0.0401
Acc: 0.9685
Epoch [1000/10000], Loss: 0.0367
Acc: 0.9682
Epoch [1100/10000], Loss: 0.0339
Acc: 0.9686
Epoch [1200/10000], Loss: 0.0318
Acc: 0.9684
Epoch [1300/10000], Loss: 0.0301
Acc: 0.9675
Epoch [1400/10000], Loss: 0.0287
Acc: 0.9674
Epoch [1500/10000], Loss: 0.0276
Acc: 0.9668
Epoch [1600/10000], Loss: 0.0265
Acc: 0.9669
Epoch [1700/10000], Loss: 0.0257
Acc: 0.9666
Epoch [1800/10000], Loss: 0.0250
Acc: 0.9663
Epoch [1900/10000], Loss: 0.0243
Acc: 0.9668
Epoch [2000/10000], Loss: 0.0237
Acc: 0.9665
Epoch [2100/10000], Loss: 0.0233
Acc: 0.9677
Epoch [2200/10000], Loss: 0.0228
Acc: 0.9674
Epoch [2300/10000], Loss: 0.0225
Acc: 0.9673
Epoch [2400/10000], Loss: 0.0222
Acc: 0.9674
Epoch [2500/10000], Loss: 0.0219
Acc: 0.9673
Epoch [2600/10000], Loss: 0.0214
Acc: 0.9674
Epoch [2700/10000], Loss: 0.0212
Acc: 0.9670
Epoch [2800/10000], Loss: 0.0210
Acc: 0.9667
Epoch [2900/10000], Loss: 0.0217
Acc: 0.9677
Epoch [3000/10000], Loss: 0.0207
Acc: 0.9665
Epoch [3100/10000], Loss: 0.0206
Acc: 0.9663
Epoch [3200/10000], Loss: 0.0204
Acc: 0.9664
Epoch [3300/10000], Loss: 0.0203
Acc: 0.9665
Epoch [3400/10000], Loss: 0.0202
Acc: 0.9665
Epoch [3500/10000], Loss: 0.0201
Acc: 0.9665
Epoch [3600/10000], Loss: 0.0200
Acc: 0.9667
Epoch [3700/10000], Loss: 0.0199
Acc: 0.9665
Epoch [3800/10000], Loss: 0.0198
Acc: 0.9669
Epoch [3900/10000], Loss: 0.0198
Acc: 0.9667
Epoch [4000/10000], Loss: 0.0197
Acc: 0.9665
Epoch [4100/10000], Loss: 0.0196
Acc: 0.9667
Epoch [4200/10000], Loss: 0.0195
Acc: 0.9666
Epoch [4300/10000], Loss: 0.0194
Acc: 0.9668
Epoch [4400/10000], Loss: 0.0194
Acc: 0.9668
Epoch [4500/10000], Loss: 0.0193
Acc: 0.9665
Epoch [4600/10000], Loss: 0.0192
Acc: 0.9664
Epoch [4700/10000], Loss: 0.0191
Acc: 0.9662
Epoch [4800/10000], Loss: 0.0191
Acc: 0.9658
Epoch [4900/10000], Loss: 0.0191
Acc: 0.9657
Epoch [5000/10000], Loss: 0.0189
Acc: 0.9660
Epoch [5100/10000], Loss: 0.0189
Acc: 0.9661
Epoch [5200/10000], Loss: 0.0188
Acc: 0.9658
Epoch [5300/10000], Loss: 0.0187
Acc: 0.9657
Epoch [5400/10000], Loss: 0.0187
Acc: 0.9658
Epoch [5500/10000], Loss: 0.0196
Acc: 0.9667
Epoch [5600/10000], Loss: 0.0191
Acc: 0.9659
Epoch [5700/10000], Loss: 0.0189
Acc: 0.9658
Epoch [5800/10000], Loss: 0.0188
Acc: 0.9657
Epoch [5900/10000], Loss: 0.0187
Acc: 0.9657
Epoch [6000/10000], Loss: 0.0187
Acc: 0.9658
Epoch [6100/10000], Loss: 0.0186
Acc: 0.9658
Epoch [6200/10000], Loss: 0.0186
Acc: 0.9659
Epoch [6300/10000], Loss: 0.0186
Acc: 0.9660
Epoch [6400/10000], Loss: 0.0185
Acc: 0.9658
Epoch [6500/10000], Loss: 0.0185
Acc: 0.9658
Epoch [6600/10000], Loss: 0.0185
Acc: 0.9658
Epoch [6700/10000], Loss: 0.0185
Acc: 0.9658
Epoch [6800/10000], Loss: 0.0184
Acc: 0.9657
Epoch [6900/10000], Loss: 0.0184
Acc: 0.9655
Epoch [7000/10000], Loss: 0.0184
Acc: 0.9655
Epoch [7100/10000], Loss: 0.0184
Acc: 0.9657
Epoch [7200/10000], Loss: 0.0184
Acc: 0.9657
Epoch [7300/10000], Loss: 0.0183
Acc: 0.9657
Epoch [7400/10000], Loss: 0.0183
Acc: 0.9657
Epoch [7500/10000], Loss: 0.0183
Acc: 0.9657
Epoch [7600/10000], Loss: 0.0183
Acc: 0.9658
Epoch [7700/10000], Loss: 0.0182
Acc: 0.9658
Epoch [7800/10000], Loss: 0.0182
Acc: 0.9659
Epoch [7900/10000], Loss: 0.0182
Acc: 0.9657
Epoch [8000/10000], Loss: 0.0182
Acc: 0.9656
Epoch [8100/10000], Loss: 0.0181
Acc: 0.9658
Epoch [8200/10000], Loss: 0.0181
Acc: 0.9658
Epoch [8300/10000], Loss: 0.0181
Acc: 0.9654
Epoch [8400/10000], Loss: 0.0181
Acc: 0.9657
Epoch [8500/10000], Loss: 0.0180
Acc: 0.9655
Epoch [8600/10000], Loss: 0.0182
Acc: 0.9654
Epoch [8700/10000], Loss: 0.0180
Acc: 0.9651
Epoch [8800/10000], Loss: 0.0207
Acc: 0.9667
Epoch [8900/10000], Loss: 0.0184
Acc: 0.9652
Epoch [9000/10000], Loss: 0.0182
Acc: 0.9647
Epoch [9100/10000], Loss: 0.0181
Acc: 0.9648
Epoch [9200/10000], Loss: 0.0180
Acc: 0.9648
Epoch [9300/10000], Loss: 0.0180
Acc: 0.9650
Epoch [9400/10000], Loss: 0.0180
Acc: 0.9650
Epoch [9500/10000], Loss: 0.0179
Acc: 0.9650
Epoch [9600/10000], Loss: 0.0179
Acc: 0.9650
Epoch [9700/10000], Loss: 0.0179
Acc: 0.9651
Epoch [9800/10000], Loss: 0.0179
Acc: 0.9652
Epoch [9900/10000], Loss: 0.0179
Acc: 0.9651
Epoch [10000/10000], Loss: 0.0178
Acc: 0.9650
[26]:
# Predict Cell type
# Persist the trained classifier, then switch it to inference mode on the GPU
torch.save(model.state_dict(), exp_dir + '/mlp_classifier.pth')
model.eval()
model.to('cuda')

# Time value attached to each simulated snapshot
ts = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float32)

# Snapshot 0 uses the annotations of the cells with samples == 0
predicted_labels_list = [df_new[df_new['samples']==0]['Annotation'].values]

for i in range(1, len(sde_point)):
    t = ts[i]
    # Simulated cells of snapshot i as a float64 CPU tensor
    traj_t = sde_point[i].detach().cpu().double()
    n_samples = traj_t.shape[0]

    # Prepend the time value as the first feature, matching the training layout
    samples_t = t * torch.ones((n_samples, 1))
    input_t = torch.cat((samples_t, traj_t), dim=1)

    with torch.no_grad():
        outputs = model(input_t.float().cuda())
        _, predicted = torch.max(outputs, 1)
        predicted_labels = label_encoder.inverse_transform(predicted.detach().cpu().numpy())

    predicted_labels_list.append(predicted_labels)


import matplotlib

cmap = matplotlib.cm.get_cmap('tab20')

all_labels = np.unique(np.concatenate(predicted_labels_list))
label_to_int = {label: idx for idx, label in enumerate(all_labels)}

predicted_colors_list = []
for labels in predicted_labels_list:
    label_indices = np.array([label_to_int[label] for label in labels])
    colors = cmap(label_indices % cmap.N)
    predicted_colors_list.append(colors)
[27]:
# Plot interpolation data
# Take entry 1 of sde_point and map it through dim_reducer for a 2-D scatter,
# coloured with the matching entry of predicted_colors_list.
data_slice = sde_point[1]
data_plot = data_slice.detach().cpu().numpy()
data_plot_reduced = dim_reducer.transform(data_plot)

fig, ax = plt.subplots(figsize=(6, 4))
ax.scatter(
    data_plot_reduced[:, 0],
    data_plot_reduced[:, 1],
    c=predicted_colors_list[1],
    alpha=1.0,
    s=25,
)
ax.set_title('Interpolation')
ax.axis('off')
fig.tight_layout()
plt.show()
../_images/notebook_analysis_35_0.png

GRN

[30]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import joblib
def _mean_jacobian(f, z):
    """
    Return the batch-averaged Jacobian of `f` with respect to `z` as an (m, n) tensor.

    `f` is (B, m) and `z` is (B, n). One backward pass is run per output column with a
    vector of ones as `grad_outputs`; this yields per-sample gradients only if each row
    of `f` depends solely on the matching row of `z` (assumed for the velocity net).
    Rows are averaged immediately so no (B, m, n) tensor is materialised.
    """
    rows = []
    for i in range(f.shape[1]):
        grad = torch.autograd.grad(
            outputs=f[:, i],
            inputs=z,
            grad_outputs=torch.ones_like(f[:, i]),
            retain_graph=True,   # the same graph is reused for every output column
        )[0]
        rows.append(grad.mean(0))
    return torch.stack(rows, dim=0)


def run_grn_v_only(f_net, df, dim, output_path, device='cuda',
                   use_gene=False, genes="", pca_loadings=None):
    """
    Compute the Jacobian using only f_net.v_net, optionally project to gene space and plot selected gene interactions.

    For every distinct value in df['samples'] the batch-averaged Jacobian of the velocity
    field is written to `jac_t{t}.csv`. With `use_gene=True` it is additionally projected
    as W @ J @ W.T, written to `jac_gene_t{t}.csv`, and the sub-matrix of the requested
    genes is plotted; otherwise the top-left (up to 3x3) block is plotted. Each figure is
    saved as a PDF and then shown.

    Args:
        f_net          : FNet model with loaded weights (must expose `.v_net(t, z)`)
        df             : DataFrame containing 'samples' and x1...xd
        dim            : Data dimension (e.g., 50)
        output_path    : Directory to save results (created if missing)
        device         : Device to run on (default 'cuda')
        use_gene       : Whether to project to original gene space
        genes          : Comma-separated gene indices (1-based), e.g. "3,5,7"
        pca_loadings   : Optional (G, dim) loading matrix used for the projection.
                         When omitted, the notebook-level `original_data.varm['PCs']`
                         is used, which matches the previous behaviour.

    Raises:
        ValueError : `use_gene=True` with empty or non-integer `genes`
        IndexError : a requested gene index is outside 1..G
    """

    # If gene-level analysis, load PCA components for projection
    if use_gene:
        if pca_loadings is None:
            # Backward-compatible fallback on the AnnData object loaded earlier in the notebook.
            pca_loadings = original_data.varm['PCs']
        W = np.asarray(pca_loadings).astype(np.float32)   # (G, dim)
        G = W.shape[0]
        print(f"PCA model loaded, number of original genes = {G}")

        # Parse genes
        if not genes.strip():
            raise ValueError("When use_gene=True, the genes parameter must be provided")
        try:
            gene_idx = [int(x.strip()) - 1 for x in genes.split(',')]
        except ValueError as err:
            raise ValueError("genes must be comma-separated integers") from err
        if any(i < 0 or i >= G for i in gene_idx):
            raise IndexError("gene indices out of range")
        # Loop-invariant labels derived from the selection.
        gene_names = [f"gene{i+1}" for i in gene_idx]
        gene_tag = ",".join(str(i+1) for i in gene_idx)
    else:
        gene_idx = []

    os.makedirs(output_path, exist_ok=True)
    f_net = f_net.to(device)

    all_times = df['samples'].values
    all_times_u = np.unique(all_times)
    all_data = df[[f'x{i}' for i in range(1, dim + 1)]].values

    for time_pt in tqdm(all_times_u, desc="Processing time points"):
        mask = all_times == time_pt
        z_np = all_data[mask].astype(np.float32)
        t_np = np.full((z_np.shape[0], 1), time_pt, np.float32)

        # Only the state needs gradients; time is a constant input.
        z_t = torch.tensor(z_np, device=device, dtype=torch.float32).requires_grad_(True)
        t_t = torch.tensor(t_np, device=device, dtype=torch.float32)

        # Compute velocity field
        v = f_net.v_net(t_t, z_t)

        jac = _mean_jacobian(v, z_t).detach().cpu().numpy()  # (dim, dim)

        # Save dim x dim Jacobian
        np.savetxt(os.path.join(output_path, f'jac_t{time_pt}.csv'),
                   jac, delimiter=',', fmt='%.6f')

        if use_gene:
            # Project to original gene space
            jac_gene = W @ jac @ W.T          # (G, G)
            np.savetxt(os.path.join(output_path, f'jac_gene_t{time_pt}.csv'),
                       jac_gene, delimiter=',', fmt='%.6f')

            # Select specified genes
            sub_jac = jac_gene[np.ix_(gene_idx, gene_idx)]

            fig, ax = plt.subplots(
                figsize=(max(2, len(gene_idx)) + 2, max(2, len(gene_idx))), dpi=300)
            sns.heatmap(sub_jac, cmap="coolwarm", square=True,
                        xticklabels=gene_names, yticklabels=gene_names, ax=ax)
            ax.set_title(f'GRN genes {gene_tag} (t={time_pt})')
            fig_path = os.path.join(output_path, f'GRN_genes_{gene_tag}_t{time_pt}.pdf')
        else:
            # Only plot the top-left 3x3 block
            dim_small = min(3, dim)
            jac_small = jac[:dim_small, :dim_small]
            axis_names = [f'x{i}' for i in range(1, dim_small + 1)]

            fig, ax = plt.subplots(figsize=(4, 3), dpi=300)
            sns.heatmap(jac_small, cmap="coolwarm", square=True,
                        xticklabels=axis_names, yticklabels=axis_names, ax=ax)
            ax.set_title(f'Jacobian v(t={time_pt})')
            fig_path = os.path.join(output_path, f'Average_jac_v_t{time_pt}.pdf')

        fig.tight_layout()
        # Save BEFORE showing: with the notebook inline backend plt.show() finalises and
        # closes the figure, so a later savefig would write an empty canvas.
        fig.savefig(fig_path, format='pdf', bbox_inches='tight')
        plt.show()
        plt.close(fig)


# Latent-space variant (no projection to genes), kept for reference:
# run_grn_v_only(f_net, df, dim, exp_dir, device='cuda', use_gene=False)

# Gene-space variant: project the Jacobian and plot genes 2, 5 and 7 (1-based indices).
run_grn_v_only(
    f_net=f_net,
    df=df,
    dim=dim,
    output_path=exp_dir,
    device='cuda',
    use_gene=True,
    genes="2,5,7",
)


PCA model loaded, number of original genes = 2447
Processing time points:   0%|          | 0/3 [00:00<?, ?it/s]
../_images/notebook_analysis_37_2.png
Processing time points:  33%|███▎      | 1/3 [00:01<00:02,  1.43s/it]
../_images/notebook_analysis_37_4.png
Processing time points:  67%|██████▋   | 2/3 [00:02<00:01,  1.41s/it]
../_images/notebook_analysis_37_6.png
Processing time points: 100%|██████████| 3/3 [00:04<00:00,  1.42s/it]