[1]:
import os, random, argparse, pandas as pd, numpy as np, seaborn as sns
from tqdm import tqdm
import torch, torch.nn as nn


# Make runs repeatable: seed each RNG the notebook relies on (Python, NumPy, PyTorch).
# CUDA_LAUNCH_BLOCKING=1 is a debugging setting (synchronous CUDA launches); it slows GPU work.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
SEED = 10
for _seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
    _seed_fn(SEED)


# load package requirments
from DeepRUOT.losses import MMD_loss, OT_loss1, OT_loss2, Density_loss, Local_density_loss
from DeepRUOT.utils import group_extract, sample, to_np, generate_steps, cal_mass_loss, parser, _valid_criterions
from DeepRUOT.plots import plot_comparision, plot_losses
from DeepRUOT.train import train_un1
from DeepRUOT.models import velocityNet, growthNet, scoreNet, dediffusionNet, indediffusionNet, FNet, ODEFunc2
from DeepRUOT.constants import ROOT_DIR, DATA_DIR, NTBK_DIR, IMGS_DIR, RES_DIR
from DeepRUOT.exp import setup_exp
from DeepRUOT.eval import generate_plot_data
from torchdiffeq import odeint_adjoint as odeint
#from torchdiffeq import odeint

# for geodesic learning
from scipy.spatial import distance_matrix
from sklearn.gaussian_process.kernels import RBF
from sklearn.manifold import MDS
[2]:
import torch.optim as optim

dim = 10  # number of gene features (x1..x10) the networks take as input
f_net = FNet(in_out_dim=dim, hidden_dim=128, n_hiddens=3, activation='leakyrelu')


import sys

# Parse the experiment options from an explicit argument list. Overwriting
# sys.argv (the usual workaround for the kernel's own '-f kernel.json' args)
# would leak into the running kernel; parse_args(list) avoids that.
cli_args = ['-d', 'file', '-c', 'ot1', '-n', 'emt']
args = parser.parse_args(cli_args)
opts = vars(args)

# Display the parsed arguments
print(opts)

# Use the GPU when it is available and enabled via the 'cuda' option;
# otherwise fall back to the CPU instead of failing at the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() and opts['cuda'] else 'cpu')
device

{'dataset': 'file', 'time_col': None, 'name': 'emt', 'output_dir': '/lustre/home/2301110060/DeepRUOT/results', 'local_epochs': 5, 'epochs': 15, 'local_post_epochs': 5, 'criterion': 'ot1', 'batches': 100, 'cuda': True, 'sample_size': 100, 'sample_with_replacement': False, 'hold_one_out': True, 'hold_out': 'random', 'apply_losses_in_time': True, 'top_k': 5, 'hinge_value': 0.01, 'use_density_loss': True, 'use_local_density': False, 'lambda_density': 1.0, 'lambda_density_local': 1.0, 'lambda_local': 0.2, 'lambda_global': 0.8, 'model_layers': [64], 'use_geo': False, 'geo_layers': [32], 'geo_features': 5, 'n_points': 100, 'n_trajectories': 30, 'n_bins': 100}
[2]:
device(type='cuda')
[3]:
# Load the EMT dataset: a 'samples' (time point) column plus gene features x1..x10.
df = pd.read_csv(os.path.join(DATA_DIR, 'emt.csv'))
# Make the output dir; exist_ok avoids the check-then-create race of isdir()+makedirs().
os.makedirs(opts['output_dir'], exist_ok=True)
exp_dir, logger = setup_exp(opts['output_dir'], opts, opts['name'])
# The logger only exists after setup_exp, so the load is logged after the fact.
logger.info('Loading dataset')

[4]:
 # setup groups
# Time points present in the data (sorted) and the consecutive transitions between them.
groups = sorted(df.samples.unique())
steps = generate_steps(groups)

logger.info(f'Defining model')
use_geo, model_layers = opts['use_geo'], opts['model_layers']
model_features = len(df.columns) - 1  # every column except the 'samples' time label

logger.info(f'Defining optimizer and criterion')
optimizer = torch.optim.Adam(f_net.parameters())

# NOTE(review): this pins the criterion to 'ot1' regardless of the parsed CLI value.
opts['criterion'] = 'ot1'
criterion = _valid_criterions[opts['criterion']]()

logger.info(f'Extracting parameters')
use_cuda = torch.cuda.is_available() and opts['cuda']
sample_size = (opts['sample_size'],)
sample_with_replacement, apply_losses_in_time = (
    opts['sample_with_replacement'],
    opts['apply_losses_in_time'],
)

# Epoch / batch schedule.
n_local_epochs, n_epochs, n_post_local_epochs, n_batches = (
    opts[key] for key in ('local_epochs', 'epochs', 'local_post_epochs', 'batches')
)

# Hold-out configuration.
hold_one_out, hold_out = opts['hold_one_out'], opts['hold_out']

# Loss hyper-parameters.
hinge_value, top_k = opts['hinge_value'], opts['top_k']
lambda_density, lambda_density_local = opts['lambda_density'], opts['lambda_density_local']
use_density_loss, use_local_density = opts['use_density_loss'], opts['use_local_density']
lambda_local, lambda_global = opts['lambda_local'], opts['lambda_global']

# Plot-data generation settings.
n_points, n_trajectories, n_bins = (
    opts[key] for key in ('n_points', 'n_trajectories', 'n_bins')
)

# Loss histories, extended by the training cells below.
local_losses = {f'{t0}:{t1}': [] for t0, t1 in steps}
batch_losses, globe_losses = [], []
[5]:
# nn.Module.to moves the parameters in place and returns the module itself,
# so no reassignment is needed; the bare name displays the architecture.
f_net.to(device)
f_net
[5]:
FNet(
  (v_net): velocityNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=11, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
      (1-2): 2 x Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
    )
    (out): Linear(in_features=128, out_features=10, bias=True)
  )
  (g_net): growthNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=11, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (s_net): scoreNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=11, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (d_net): indediffusionNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=1, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
[6]:
# Number of cells observed at the first time point (samples == 0).
initial_size = int((df['samples'] == 0).sum())
initial_size
[6]:
577
[7]:
# Cell count per time point, expressed relative to the first time point.
sample_sizes = df.groupby('samples').size()
ref0 = sample_sizes.div(sample_sizes.iloc[0])
relative_mass = torch.tensor(ref0.to_numpy())
relative_mass
[7]:
tensor([1.0000, 1.5338, 1.3657, 1.5303], dtype=torch.float64)
[8]:
# Override opts['sample_size']: use the full size of the t=0 population instead.
sample_size = (len(df.loc[df['samples'] == 0.0]),)

Pretrain velocity and growth

[9]:
    # Stage 1: pretrain velocity and growth together (OT loss plus mass loss with
    # lambda_mass=1; the growth net is still trainable at this point).
    if n_local_epochs > 0:
        logger.info(f'Beginning pretraining')
        # Keyword arguments for train_un1, gathered once; none change inside the loop.
        stage1_kwargs = dict(
            criterion=criterion,
            use_cuda=use_cuda,
            local_loss=True,
            global_loss=False,
            apply_losses_in_time=apply_losses_in_time,
            hold_one_out=hold_one_out,
            hold_out=hold_out,
            hinge_value=hinge_value,
            lambda_ot=0.1,
            lambda_mass=1,
            lambda_energy=0.001,
            use_pinn=False,
            use_penalty=False,
            use_density_loss=False,
            lambda_density=10,
            top_k=top_k,
            sample_size=sample_size,
            relative_mass=relative_mass,
            initial_size=initial_size,
            sample_with_replacement=sample_with_replacement,
            logger=logger,
            device=device,
            best_model_path=exp_dir + '/best_model',
        )
        for epoch in tqdm(range(1), desc='Pretraining Epoch'):
            # NOTE(review): the 5th positional argument (30) presumably sets the
            # number of training iterations - confirm against train_un1.
            l_loss, b_loss, g_loss = train_un1(f_net, df, groups, optimizer, 30, **stage1_kwargs)
            for step_key, step_losses in l_loss.items():
                local_losses[step_key].extend(step_losses)
            batch_losses.extend(b_loss)
            globe_losses.extend(g_loss)

Pretraining Epoch:   0%|          | 0/1 [00:00<?, ?it/s]
begin local loss
Otloss
tensor(0.1700, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0193, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3781, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0798, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.6073, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5121, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.6072728633880615. Model saved.
Otloss
tensor(0.1137, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0103, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1791, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0220, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2457, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0462, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.2457018494606018. Model saved.
Otloss
tensor(0.0945, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0075, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1120, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0085, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1666, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0187, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.16664665937423706. Model saved.
Otloss
tensor(0.0936, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0078, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1187, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0084, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2180, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0236, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0930, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0075, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1173, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0091, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1955, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0210, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0862, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0066, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1011, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0075, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1500, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0151, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.14996205270290375. Model saved.
Otloss
tensor(0.0804, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0057, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0945, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0075, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1376, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0153, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.13762015104293823. Model saved.
Otloss
tensor(0.0775, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0050, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0960, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0090, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1346, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0146, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.1345514953136444. Model saved.
Otloss
tensor(0.0760, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0045, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0976, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0087, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1217, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0128, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.12168321758508682. Model saved.
Otloss
tensor(0.0742, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0045, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0977, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0076, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1055, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0095, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.10546368360519409. Model saved.
Otloss
tensor(0.0718, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0045, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0970, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0070, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0953, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0074, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.09534205496311188. Model saved.
Otloss
tensor(0.0698, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0040, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0946, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0066, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0920, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0063, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.09199593216180801. Model saved.
Otloss
tensor(0.0696, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0042, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0920, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0071, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0946, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0069, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0703, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0043, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0886, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0065, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0975, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0076, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0699, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0043, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0844, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0060, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0961, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0074, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0694, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0042, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0837, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0056, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0940, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0069, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0690, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0040, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0852, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0054, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0904, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0066, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.09037873893976212. Model saved.
Otloss
tensor(0.0682, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0040, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0854, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0058, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0855, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0061, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.08546467125415802. Model saved.
Otloss
tensor(0.0677, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0040, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0848, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0059, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0834, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0062, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.0834236741065979. Model saved.
Otloss
tensor(0.0675, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0038, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0836, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0056, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0833, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0059, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.08330996334552765. Model saved.
Otloss
tensor(0.0673, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0038, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0813, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0049, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0823, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0058, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.0823383778333664. Model saved.
Otloss
tensor(0.0671, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0039, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0792, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0047, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0826, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0057, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0669, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0039, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0780, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0045, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0829, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0055, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0665, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0038, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0769, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0044, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0805, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0053, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.08051622658967972. Model saved.
Otloss
tensor(0.0660, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0038, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0768, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0045, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0800, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0051, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.07996969670057297. Model saved.
Otloss
tensor(0.0655, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0037, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0766, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0046, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0785, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0049, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.07851728796958923. Model saved.
Otloss
tensor(0.0652, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0038, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0754, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0043, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0782, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0052, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.0781906396150589. Model saved.
Otloss
tensor(0.0650, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0038, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0734, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0037, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0790, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0051, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0648, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0037, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0722, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0039, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0777, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0044, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.07765759527683258. Model saved.
Otloss
tensor(0.0645, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0038, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0723, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0039, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0776, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0045, device='cuda:0', grad_fn=<PowBackward0>)
Pretraining Epoch: 100%|██████████| 1/1 [00:56<00:00, 56.98s/it]
New minimum otloss found: 0.07756981253623962. Model saved.

[10]:
# Restore the checkpoint that train_un1 saved at the lowest OT loss, then freeze
# the growth network so the next stage only updates the remaining sub-networks.
best_model_path = os.path.join(exp_dir, 'best_model')
# The checkpoint is a plain state dict, so weights_only=True is sufficient (and avoids unpickling arbitrary objects).
f_net.load_state_dict(torch.load(best_model_path, map_location=torch.device('cpu'), weights_only=True))
f_net.to(device)
# NOTE(review): g_net params remain registered in `optimizer`; Adam skips them only
# while their .grad is None - confirm train_un1 zeroes grads with set_to_none.
for param in f_net.g_net.parameters():
    param.requires_grad = False

[11]:
    # Stage 2: refine with the growth net frozen (cell [10]) and the mass term
    # switched off (lambda_mass=0); everything else matches stage 1.
    if n_local_epochs > 0:
        logger.info(f'Beginning pretraining')
        # Keyword arguments for train_un1, gathered once; none change inside the loop.
        stage2_kwargs = dict(
            criterion=criterion,
            use_cuda=use_cuda,
            local_loss=True,
            global_loss=False,
            apply_losses_in_time=apply_losses_in_time,
            hold_one_out=hold_one_out,
            hold_out=hold_out,
            hinge_value=hinge_value,
            lambda_ot=0.1,
            lambda_mass=0,
            lambda_energy=0.001,
            use_pinn=False,
            use_penalty=False,
            use_density_loss=False,
            lambda_density=10,
            top_k=top_k,
            sample_size=sample_size,
            relative_mass=relative_mass,
            initial_size=initial_size,
            sample_with_replacement=sample_with_replacement,
            logger=logger,
            device=device,
            best_model_path=exp_dir + '/best_model',
        )
        for epoch in tqdm(range(1), desc='Pretraining Epoch'):
            # NOTE(review): the 5th positional argument (10) presumably sets the
            # number of training iterations - confirm against train_un1.
            l_loss, b_loss, g_loss = train_un1(f_net, df, groups, optimizer, 10, **stage2_kwargs)
            for pair_key, pair_losses in l_loss.items():
                local_losses[pair_key].extend(pair_losses)
            batch_losses.extend(b_loss)
            globe_losses.extend(g_loss)

Pretraining Epoch:   0%|          | 0/1 [00:00<?, ?it/s]
begin local loss
Otloss
tensor(0.0642, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0038, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0713, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0035, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0780, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0048, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.07804079353809357. Model saved.
Otloss
tensor(0.0640, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0037, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0708, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0035, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0761, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0048, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.07610562443733215. Model saved.
Otloss
tensor(0.0639, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0037, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0715, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0039, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0775, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0047, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0636, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0037, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0697, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0031, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0764, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0045, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0638, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0037, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0701, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0033, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0768, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0046, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0634, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0036, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0693, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0033, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0773, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0044, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0635, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0037, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0688, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0033, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0761, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0047, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.07606751471757889. Model saved.
Otloss
tensor(0.0631, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0036, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0692, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0031, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0779, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0048, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0632, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0035, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0695, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0032, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0754, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0047, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.07541188597679138. Model saved.
Otloss
tensor(0.0631, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0036, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0685, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0032, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0764, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0044, device='cuda:0', grad_fn=<PowBackward0>)
Pretraining Epoch: 100%|██████████| 1/1 [00:17<00:00, 17.16s/it]
[12]:
# Reload the best checkpoint written by the stage-2 run; the checkpoint is a plain
# state dict, so weights_only=True is sufficient. The final expression shows the model.
f_net.load_state_dict(torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu'), weights_only=True))
f_net.to(device)
[12]:
FNet(
  (v_net): velocityNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=11, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
      (1-2): 2 x Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
    )
    (out): Linear(in_features=128, out_features=10, bias=True)
  )
  (g_net): growthNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=11, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (s_net): scoreNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=11, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (d_net): indediffusionNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=1, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
[13]:
import torch
import matplotlib.pyplot as plt
import numpy as np

def plot_g_values(df, f_net, device=device, output_file='plot.pdf', n_dims=10):
    """Scatter cells on genes x1/x2, coloured by the growth rate predicted by ``f_net``.

    For every time point in ``df['samples']`` the gene columns ``x1``..``x{n_dims}``
    of that time point's cells are passed through ``f_net`` at that time, and the
    second element of its output is used as the growth value. Colours share one
    normalisation across all time points, clipped at the 99th percentile so a few
    extreme cells do not wash out the scale.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``samples`` (time) column and gene columns ``x1``..``x{n_dims}``.
    f_net : torch.nn.Module
        Called as ``f_net(t, data)`` and unpacked into 4 values; the second is
        taken as the per-cell growth rate.
    device : torch.device
        Device the inputs are placed on (should match ``f_net``'s device).
    output_file : str or None
        Path the figure is saved to before it is shown; ``None`` skips saving.
    n_dims : int
        Number of gene columns fed to ``f_net`` (previously hard-coded to 10).

    Raises
    ------
    ValueError
        If ``df`` contains no rows to plot.
    """
    column_names = [f'x{i}' for i in range(1, n_dims + 1)]

    data_by_time = {}
    for time in df['samples'].unique():
        subset = df[df['samples'] == time]
        data = torch.tensor(subset[column_names].to_numpy(), dtype=torch.float32, device=device)
        t = torch.tensor([time], dtype=torch.float32, device=device)
        with torch.no_grad():
            _, g, _, _ = f_net(t, data)
        # Flatten so both (N,) and (N, 1) growth outputs give exactly one colour per cell.
        data_by_time[time] = {'data': subset, 'g_values': g.detach().cpu().numpy().reshape(-1)}

    if not data_by_time:
        raise ValueError('plot_g_values: df has no rows to plot')

    all_g_values = np.concatenate([content['g_values'] for content in data_by_time.values()])
    vmax_value = np.percentile(all_g_values, 99)
    norm = plt.Normalize(vmin=all_g_values.min(), vmax=vmax_value, clip=True)

    fig, ax = plt.subplots(figsize=(12, 8))
    for time, content in data_by_time.items():
        subset = content['data']
        colors = plt.cm.rainbow(norm(content['g_values']))
        ax.scatter(subset['x1'], subset['x2'], color=colors, label=f'Time {time}', alpha=0.7, marker='o')

    ax.set_xlabel('Gene $X_1$')
    ax.set_ylabel('Gene $X_2$')

    sm = plt.cm.ScalarMappable(cmap='rainbow', norm=norm)
    sm.set_array(all_g_values)
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label('Normalized predicted growth rate')
    # Tick positions stay in raw growth units; the labels show their normalised value.
    cbar.ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{norm(x):.2f}'))

    # Previously output_file was accepted but never used, so nothing was written.
    if output_file:
        fig.savefig(output_file, bbox_inches='tight')
    plt.show()

# Growth-rate map of the currently loaded f_net, coloured per cell.
plot_g_values(df=df, f_net=f_net, output_file='gene_growth_pre_post.pdf')
../_images/notebook_emt_13_0.png
[15]:
    # generate plot data
    f_net.to('cuda')
    generated, trajectories = generate_plot_data(
        f_net, df, n_points=400, n_trajectories=50, n_bins=100,
        sample_with_replacement=True, use_cuda=use_cuda, samples_key='samples',
        logger=logger
    )

tensor([0., 1., 2., 3.], device='cuda:0')
<class 'torch.Tensor'>
tensor([0.0000, 0.0303, 0.0606, 0.0909, 0.1212, 0.1515, 0.1818, 0.2121, 0.2424,
        0.2727, 0.3030, 0.3333, 0.3636, 0.3939, 0.4242, 0.4545, 0.4848, 0.5152,
        0.5455, 0.5758, 0.6061, 0.6364, 0.6667, 0.6970, 0.7273, 0.7576, 0.7879,
        0.8182, 0.8485, 0.8788, 0.9091, 0.9394, 0.9697, 1.0000, 1.0303, 1.0606,
        1.0909, 1.1212, 1.1515, 1.1818, 1.2121, 1.2424, 1.2727, 1.3030, 1.3333,
        1.3636, 1.3939, 1.4242, 1.4545, 1.4848, 1.5152, 1.5455, 1.5758, 1.6061,
        1.6364, 1.6667, 1.6970, 1.7273, 1.7576, 1.7879, 1.8182, 1.8485, 1.8788,
        1.9091, 1.9394, 1.9697, 2.0000, 2.0303, 2.0606, 2.0909, 2.1212, 2.1515,
        2.1818, 2.2121, 2.2424, 2.2727, 2.3030, 2.3333, 2.3636, 2.3939, 2.4242,
        2.4545, 2.4848, 2.5152, 2.5455, 2.5758, 2.6061, 2.6364, 2.6667, 2.6970,
        2.7273, 2.7576, 2.7879, 2.8182, 2.8485, 2.8788, 2.9091, 2.9394, 2.9697,
        3.0000], device='cuda:0')
<class 'torch.Tensor'>
[17]:
    # Compare generated points/trajectories with the observed data on genes x1 vs x2
    # (save=True with path/file presumably writes exp_dir/comparision.png).
    plot_comparision(
        df, generated, trajectories,
        x='x1', y='x2', is_3d=False,
        df_time_key='samples', palette='viridis',
        save=True, path=exp_dir, file='comparision.png',
    )
[17]:
../_images/notebook_emt_15_0.png

Pretrain score

[18]:
import pandas as pd
import anndata as ad

import scanpy as sc

# Reload the best checkpoint (a plain state dict) before building the score-training data.
f_net.load_state_dict(torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu'), weights_only=True))
f_net.to(device)

print("DataFrame shape:", df.shape)
print("DataFrame columns:", df.columns)
n = dim
samples = df['samples'].values
column_names = [f'x{i}' for i in range(1, n + 1)]

obsm_data = df[column_names].values
print("obsm_data shape:", obsm_data.shape)

# Use unique string row labels as obs_names and keep the time label in its own
# column. Using the repeating float time labels as the index made AnnData warn
# about a non-string, non-unique obs index.
adata = ad.AnnData(
    obs=pd.DataFrame({'samples': samples}, index=[str(i) for i in range(len(samples))])
)
adata.obsm['X_pca'] = obsm_data

adata_loaded = adata
print(adata_loaded)

sc.pl.scatter(adata, basis="pca", color="samples")

# One (n_cells_t, n) array per time point, in sorted time order. Iterating over the
# actual labels (not range(n_times)) stays correct for non-integer time labels.
time_points = sorted(adata.obs["samples"].unique())
n_times = len(time_points)
print(n_times)
X = [
    adata.obsm["X_pca"][(adata.obs["samples"] == t).to_numpy()]
    for t in time_points
]
DataFrame shape: (3133, 11)
DataFrame columns: Index(['samples', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10'], dtype='object')
obsm_data shape: (3133, 10)
AnnData object with n_obs × n_vars = 3133 × 0
    obsm: 'X_pca'
/lustre/home/2301110060/software/miniconda3/envs/DeepRUOT/lib/python3.10/site-packages/anndata/_core/aligned_df.py:68: ImplicitModificationWarning: Transforming to str index.
  warnings.warn("Transforming to str index.", ImplicitModificationWarning)
/lustre/home/2301110060/software/miniconda3/envs/DeepRUOT/lib/python3.10/site-packages/anndata/_core/anndata.py:1756: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  utils.warn_names_duplicates("obs")
../_images/notebook_emt_17_2.png
4
[25]:
from DeepRUOT.utils import OTPlanSampler, ConditionalFlowMatcher, ExactOptimalTransportConditionalFlowMatcher, TargetConditionalFlowMatcher,SchrodingerBridgeConditionalFlowMatcher, VariancePreservingConditionalFlowMatcher,generate_state_trajectory
from DeepRUOT.models import scoreNet2

# Number of cells observed at the first time point; used as the per-interval batch size.
batch_size = int((df['samples'] == 0).sum())
sigma = 0.05  # bridge noise level shared by the flow matcher and (later) the SDE diffusion
time = torch.Tensor(groups)

SF2M = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)
sf2m_score_model = (
    scoreNet2(in_out_dim=dim, hidden_dim=128, activation='leakyrelu')
    .float()
    .to(device)
)
sf2m_optimizer = torch.optim.Adam(sf2m_score_model.parameters(), 1e-4)

# Trajectories of the observed cells generated with the pretrained f_net.
trajectory = generate_state_trajectory(X, n_times, batch_size, f_net, time, device)
[26]:
from DeepRUOT.utils import get_batch

# Pretrain the score network with a denoising score-matching loss on SF2M bridge samples.
N_SCORE_STEPS = 3001  # training iterations
LOG_EVERY = 100
lambda_penalty = 0    # weight on max(relu(score potential)); 0 disables the penalty

for i in tqdm(range(N_SCORE_STEPS)):
    sf2m_optimizer.zero_grad()
    t, xt, ut, eps = get_batch(SF2M, X, trajectory, batch_size, n_times, return_noise=True)
    t = torch.unsqueeze(t, 1)
    # t % 1 keeps only the fractional position inside a time interval
    # (presumably get_batch offsets t by the interval index -- verify).
    lambda_t = SF2M.compute_lambda(t % 1)
    value_st = sf2m_score_model(t, xt)
    st = sf2m_score_model.compute_gradient(t, xt)
    positive_st = torch.relu(value_st)
    penalty = lambda_penalty * torch.max(positive_st)

    # lambda_t must weight each sample's own score: shape (B, 1) against st/eps (B, dim).
    # The previous `lambda_t[:, None]` turned an already (B, 1) lambda_t (t was unsqueezed
    # above) into (B, 1, 1), which silently broadcast to a (B, B, dim) cross-sample loss.
    # reshape(-1, 1) is correct whether lambda_t comes back as (B,) or (B, 1).
    score_loss = torch.mean((lambda_t.reshape(-1, 1) * st + eps) ** 2)
    loss = score_loss + penalty
    if i % LOG_EVERY == 0:
        print(torch.max(positive_st).item())
        print(f"{i}:  {score_loss.item():0.2f}")

    loss.backward()
    sf2m_optimizer.step()
  0%|          | 2/3001 [00:00<05:05,  9.83it/s]
tensor(0.0455, device='cuda:0', grad_fn=<MaxBackward1>)
0:  1.01
  3%|▎         | 102/3001 [00:10<04:52,  9.90it/s]
tensor(0.0121, device='cuda:0', grad_fn=<MaxBackward1>)
100:  0.99
  7%|▋         | 202/3001 [00:20<04:43,  9.89it/s]
tensor(0.0203, device='cuda:0', grad_fn=<MaxBackward1>)
200:  1.00
 10%|█         | 302/3001 [00:30<04:31,  9.94it/s]
tensor(0.0258, device='cuda:0', grad_fn=<MaxBackward1>)
300:  0.99
 13%|█▎        | 402/3001 [00:40<04:21,  9.94it/s]
tensor(0.0272, device='cuda:0', grad_fn=<MaxBackward1>)
400:  0.99
 17%|█▋        | 502/3001 [00:50<04:11,  9.94it/s]
tensor(0.0335, device='cuda:0', grad_fn=<MaxBackward1>)
500:  0.98
 20%|██        | 602/3001 [01:00<04:01,  9.94it/s]
tensor(0.0322, device='cuda:0', grad_fn=<MaxBackward1>)
600:  0.98
 23%|██▎       | 702/3001 [01:10<03:51,  9.93it/s]
tensor(0.0401, device='cuda:0', grad_fn=<MaxBackward1>)
700:  0.99
 27%|██▋       | 802/3001 [01:20<03:41,  9.94it/s]
tensor(0.0458, device='cuda:0', grad_fn=<MaxBackward1>)
800:  1.00
 30%|███       | 902/3001 [01:30<03:31,  9.95it/s]
tensor(0.0484, device='cuda:0', grad_fn=<MaxBackward1>)
900:  0.97
 33%|███▎      | 1002/3001 [01:40<03:20,  9.95it/s]
tensor(0.0477, device='cuda:0', grad_fn=<MaxBackward1>)
1000:  0.99
 37%|███▋      | 1102/3001 [01:50<03:11,  9.93it/s]
tensor(0.0474, device='cuda:0', grad_fn=<MaxBackward1>)
1100:  0.98
 40%|████      | 1202/3001 [02:00<03:01,  9.93it/s]
tensor(0.0478, device='cuda:0', grad_fn=<MaxBackward1>)
1200:  1.00
 43%|████▎     | 1302/3001 [02:10<02:51,  9.92it/s]
tensor(0.0546, device='cuda:0', grad_fn=<MaxBackward1>)
1300:  1.01
 47%|████▋     | 1402/3001 [02:20<02:41,  9.89it/s]
tensor(0.0539, device='cuda:0', grad_fn=<MaxBackward1>)
1400:  0.99
 50%|█████     | 1502/3001 [02:30<02:30,  9.95it/s]
tensor(0.0583, device='cuda:0', grad_fn=<MaxBackward1>)
1500:  0.98
 53%|█████▎    | 1602/3001 [02:40<02:20,  9.95it/s]
tensor(0.0567, device='cuda:0', grad_fn=<MaxBackward1>)
1600:  0.99
 57%|█████▋    | 1702/3001 [02:51<02:10,  9.95it/s]
tensor(0.0619, device='cuda:0', grad_fn=<MaxBackward1>)
1700:  0.99
 60%|██████    | 1802/3001 [03:01<02:00,  9.95it/s]
tensor(0.0643, device='cuda:0', grad_fn=<MaxBackward1>)
1800:  0.98
 63%|██████▎   | 1902/3001 [03:11<01:50,  9.93it/s]
tensor(0.0706, device='cuda:0', grad_fn=<MaxBackward1>)
1900:  0.98
 67%|██████▋   | 2002/3001 [03:21<01:40,  9.95it/s]
tensor(0.0707, device='cuda:0', grad_fn=<MaxBackward1>)
2000:  0.98
 70%|███████   | 2102/3001 [03:31<01:30,  9.94it/s]
tensor(0.0780, device='cuda:0', grad_fn=<MaxBackward1>)
2100:  1.00
 73%|███████▎  | 2202/3001 [03:41<01:20,  9.94it/s]
tensor(0.0758, device='cuda:0', grad_fn=<MaxBackward1>)
2200:  0.99
 77%|███████▋  | 2302/3001 [03:51<01:10,  9.94it/s]
tensor(0.0805, device='cuda:0', grad_fn=<MaxBackward1>)
2300:  1.01
 80%|████████  | 2402/3001 [04:01<01:00,  9.92it/s]
tensor(0.0822, device='cuda:0', grad_fn=<MaxBackward1>)
2400:  1.00
 83%|████████▎ | 2502/3001 [04:11<00:50,  9.91it/s]
tensor(0.0834, device='cuda:0', grad_fn=<MaxBackward1>)
2500:  0.98
 87%|████████▋ | 2602/3001 [04:21<00:40,  9.91it/s]
tensor(0.0845, device='cuda:0', grad_fn=<MaxBackward1>)
2600:  0.99
 90%|█████████ | 2702/3001 [04:31<00:30,  9.90it/s]
tensor(0.0882, device='cuda:0', grad_fn=<MaxBackward1>)
2700:  0.98
 93%|█████████▎| 2802/3001 [04:41<00:20,  9.94it/s]
tensor(0.0902, device='cuda:0', grad_fn=<MaxBackward1>)
2800:  0.99
 97%|█████████▋| 2902/3001 [04:51<00:09,  9.94it/s]
tensor(0.0921, device='cuda:0', grad_fn=<MaxBackward1>)
2900:  0.99
100%|██████████| 3001/3001 [05:01<00:00,  9.95it/s]
tensor(0.0913, device='cuda:0', grad_fn=<MaxBackward1>)
3000:  0.99

[27]:
# Persist the pretrained score network so later cells can reload it.
score_model_path = os.path.join(exp_dir, 'score_model')
torch.save(sf2m_score_model.state_dict(), score_model_path)
[30]:
import numpy as np
import matplotlib.pyplot as plt
import torch



# Visualise the pretrained score network at a fixed time on a 2-D grid over (x1, x2);
# the remaining dim - 2 coordinates are held at zero.
GRID_N = 100     # grid resolution per axis
t_value = 2.0    # time at which the score network is evaluated
step = 5         # draw every `step`-th arrow of the gradient field

x_range = np.linspace(0, 1, GRID_N)
y_range = np.linspace(0, 1, GRID_N)
xv, yv = np.meshgrid(x_range, y_range)
grid_points = np.stack([xv, yv], axis=-1).reshape(-1, 2)
n_grid = grid_points.shape[0]

grid_points_tensor = torch.tensor(grid_points).float().to(device)
padding_tensor = torch.zeros(n_grid, dim - 2, device=device)
expanded_tensor = torch.cat((grid_points_tensor, padding_tensor), dim=1).requires_grad_(True)

t_tensor = torch.full((n_grid, 1), t_value, device=device)
print(t_tensor.shape)

log_density_values = sf2m_score_model(t_tensor, expanded_tensor)
# autograd.grad returns d(output)/d(input) without accumulating into the model parameters'
# .grad buffers; .backward() did, which a later optimiser step could pick up unless zeroed.
(gradients,) = torch.autograd.grad(log_density_values.sum(), expanded_tensor)
print(gradients.shape)
gradients_np = gradients.detach().cpu().numpy().reshape(GRID_N, GRID_N, dim)

xv_quiver = xv[::step, ::step]
yv_quiver = yv[::step, ::step]
gradients_np_quiver = gradients_np[::step, ::step, :]

plt.figure(figsize=(8, 6))
plt.contourf(xv, yv, -log_density_values.detach().cpu().numpy().reshape(GRID_N, GRID_N), levels=50, cmap='rainbow')
plt.colorbar(label=' ')
plt.quiver(xv_quiver, yv_quiver, gradients_np_quiver[:, :, 0], gradients_np_quiver[:, :, 1], color='white', scale=1)
plt.xlabel('x1')
plt.ylabel('x2')
plt.title(f'Score model at t = {t_value}')
plt.show()
torch.Size([10000, 1])
torch.Size([10000, 10])
../_images/notebook_emt_21_1.png

Train: jointly fine-tune the dynamics model (f_net) and the score model

[31]:
# Reload the pretrained score network and the best f_net checkpoint before joint training.
sf2m_score_model.load_state_dict(torch.load(os.path.join(exp_dir, 'score_model'), map_location=torch.device('cpu')))
sf2m_score_model.to(device)
f_net.load_state_dict(torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu')))
f_net.to(device)
# Make every growth-network parameter trainable (presumably frozen in an earlier stage -- verify).
# The trailing ';' suppresses the module repr that requires_grad_ would otherwise display.
f_net.g_net.requires_grad_(True);
[31]:
FNet(
  (v_net): velocityNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=11, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
      (1-2): 2 x Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
    )
    (out): Linear(in_features=128, out_features=10, bias=True)
  )
  (g_net): growthNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=11, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (s_net): scoreNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=11, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (d_net): indediffusionNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=1, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
[13]:
import numpy as np
from DeepRUOT.utils import density1
import matplotlib.pyplot as plt
import torch

# Reference cells at the first time point, embedded in `dim` dimensions with only
# (x1, x2) filled in and the remaining coordinates left at zero.
first_time_rows = df[df['samples'] == 0]
datatime0 = torch.zeros(first_time_rows.shape[0], dim)
datatime0[:, 0] = torch.tensor(first_time_rows.x1.values, dtype=torch.float32)
datatime0[:, 1] = torch.tensor(first_time_rows.x2.values, dtype=torch.float32)
# NOTE: rebinds the notebook-global `device`; later cells that read it will run on the CPU.
device = 'cpu'

GRID_N = 100  # grid resolution per axis
x_range = np.linspace(0, 1, GRID_N)
y_range = np.linspace(0, 1, GRID_N)
xv, yv = np.meshgrid(x_range, y_range)
num_grid_points = xv.size  # GRID_N * GRID_N

# Grid in `dim` dimensions: the first two coordinates vary over the mesh, the rest are zero.
grid_points = np.zeros((num_grid_points, dim), dtype=np.float32)
grid_points[:, 0] = xv.flatten()
grid_points[:, 1] = yv.flatten()
grid_points_tensor = torch.tensor(grid_points).float().to(device)

with torch.no_grad():
    function_values = density1(grid_points_tensor, datatime0, device).cpu().detach().numpy().reshape(GRID_N, GRID_N)

plt.figure(figsize=(8, 6))
plt.contourf(xv, yv, function_values, levels=100, cmap='rainbow')
plt.colorbar(label='Density value')
plt.xlabel('x1')
plt.ylabel('x2')
# Derive the title from the grid actually plotted, so it cannot drift from the ranges above.
plt.title(f'Density Plot on [{x_range[0]:g}, {x_range[-1]:g}] x [{y_range[0]:g}, {y_range[-1]:g}]')
plt.show()
../_images/notebook_emt_24_0.png
[14]:
from DeepRUOT.train import train_all
device = 'cpu'
# A single SGD optimiser over both the dynamics model and the score network.
joint_params = [*f_net.parameters(), *sf2m_score_model.parameters()]
optimizer = torch.optim.SGD(joint_params, 1e-5)

[ ]:
    # Joint training of f_net and sf2m_score_model via DeepRUOT.train.train_all.
    # This cell has no execution count in the saved notebook. It depends on names not defined
    # in this section (n_local_epochs, logger, criterion, use_cuda, apply_losses_in_time,
    # hold_one_out, hold_out, hinge_value, top_k, sample_size, relative_mass, initial_size,
    # sample_with_replacement, local_losses, batch_losses, globe_losses). Presumably they come
    # from the experiment setup earlier in the notebook -- verify before running.

    if n_local_epochs > 0:
        logger.info(f'Beginning Training')
        # range(1): exactly one call to train_all. The positional 10 is a train_all argument
        # whose meaning is not visible here (TODO confirm against DeepRUOT.train.train_all).
        for epoch in tqdm(range(1), desc='Training Epoch'):
            l_loss, b_loss, g_loss = train_all(
                f_net, df, groups, optimizer,10,
                criterion = criterion, use_cuda = use_cuda,
                local_loss=True, global_loss=False, apply_losses_in_time=apply_losses_in_time,
                hold_one_out=hold_one_out, hold_out=hold_out, sf2m_score_model=sf2m_score_model,
                 hinge_value=hinge_value,datatime0=datatime0,device=device, lambda_initial=0.1,
                 use_pinn=True, use_penalty=True,use_density_loss=False,lambda_density=10,
                top_k = top_k, sample_size = sample_size,relative_mass=relative_mass,initial_size=initial_size,
                sample_with_replacement = sample_with_replacement, logger=logger, sigmaa=sigma,lambda_pinn=1,
            )
            # Append this call's losses to the notebook-level histories
            # (l_loss is a dict of lists; b_loss / g_loss are extended as sequences).
            for k, v in l_loss.items():
                local_losses[k].extend(v)
            batch_losses.extend(b_loss)
            globe_losses.extend(g_loss)

[17]:
# Persist the jointly trained score network and dynamics model.
for module_to_save, file_name in (
    (sf2m_score_model, 'score_model_result'),
    (f_net, 'model_result'),
):
    torch.save(module_to_save.state_dict(), os.path.join(exp_dir, file_name))
[32]:
# Initial SDE state: every cell observed at time point 0. Column 0 of df is 'samples',
# so columns 1..dim are the x1..x{dim} features.
rows_t0 = df.loc[df['samples'] == 0]
data = torch.tensor(rows_t0.values, dtype=torch.float32).requires_grad_()
data_t0 = data[:, 1:1 + dim].to(device).requires_grad_()
print(data_t0.shape)
x0 = data_t0.to(device)
torch.Size([577, 10])
[37]:
import torchsde
class SDE(torch.nn.Module):
    """Ito SDE with diagonal noise, in the form expected by ``torchsde.sdeint``.

    Drift:     f(t, y) = ode_drift(t, y) + score.compute_gradient(t_batch, y)
    Diffusion: g(t, y) = sigma, the same constant on every coordinate.

    Args:
        ode_drift: callable ``(t, y) -> tensor``. It is called with the solver's time ``t`` as-is,
            and its result is added to the score term (presumably shaped like ``y``).
        score: object exposing ``compute_gradient(t, y)``. ``t`` is expanded to ``(batch, 1)``
            before the call.
        input_size: stored on the instance; not used by ``f`` or ``g``.
        sigma: constant diffusion coefficient.
    """

    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, ode_drift, score, input_size=(3, 32, 32), sigma=1.0):
        super().__init__()
        self.drift = ode_drift
        self.score = score
        self.input_size = input_size
        self.sigma = sigma

    def f(self, t, y):
        """Drift: ODE velocity plus the score-model gradient at the current state."""
        drift = self.drift(t, y)
        # The score model takes one time value per row of y.
        t_batch = t.expand(y.shape[0], 1)
        return drift + self.score.compute_gradient(t_batch, y)

    def g(self, t, y):
        """Diffusion: constant ``self.sigma`` on every coordinate.

        Uses the instance's sigma. Reading a notebook-global ``sigma`` ignored the constructor
        argument and broke whenever that global was missing or different.
        """
        return torch.ones_like(y) * self.sigma


sde = SDE(f_net.v_net, sf2m_score_model, input_size=(dim,), sigma=sigma)
# 400 evenly spaced output times spanning the whole observed range, used for trajectory plots.
dense_ts = torch.linspace(0, n_times - 1, 400, device=device)
sde_traj = torchsde.sdeint(sde, x0.to(device), ts=dense_ts, dt=0.01).cpu()
[38]:
# Randomly pick a few simulated cells and save their dense SDE trajectories for plotting.
sample_number = 10

sample_indices = random.sample(range(sde_traj.size(1)), sample_number)
sampled_sde_trajec = sde_traj[:, sample_indices, :]

# Save a plain float array. The previous tolist() -> dtype=object round trip produced a
# pickled .npy for no benefit. detach() is required because x0 was created with
# requires_grad, so the trajectory can carry autograd history.
sampled_sde_trajec = sampled_sde_trajec.detach().numpy()
np.save(os.path.join(exp_dir, 'sde_trajec_our_post_plot.npy'), sampled_sde_trajec)

[40]:
# Re-simulate the SDE from x0, recording the state only at the observed time points.
ts_points = time.to(device)
sde_point = torchsde.sdeint(sde, x0.to(device), ts=ts_points, dt=0.01).cpu()
[41]:
# Save the simulated states at the observed time points as a plain float array.
# The previous tolist() -> dtype=object conversion needlessly required a pickled .npy.
sde_point_np = sde_point.detach().numpy()
np.save(os.path.join(exp_dir, 'sde_point_our_post.npy'), sde_point_np)


[42]:
sde_point_file = exp_dir + '/sde_point_our_post.npy'
sde_point_our_post = np.load(sde_point_file, allow_pickle=True)

[27]:
# Keep a fixed subset of 200 simulated cells (indices 100..299) for plotting.
# Slice a freshly loaded copy so re-running this cell does not shrink an already-sliced array
# (the previous in-place reassignment gave (4, 100, 10), then (4, 0, 10), on repeated runs).
SUBSET = slice(100, 300)
sde_point_our_post = np.load(exp_dir + '/sde_point_our_post.npy', allow_pickle=True)[:, SUBSET, :]
sde_point_our_post.shape
[27]:
(4, 200, 10)
[43]:
from DeepRUOT.plots import new_plot_comparisions2
[44]:
trajectory_file = exp_dir + '/sde_trajec_our_post_plot.npy'
sde_trajec_our_post_plot = np.load(trajectory_file, allow_pickle=True)
# Observed data vs. SDE-simulated points and the sampled trajectories in the (x1, x2) plane.
new_plot_comparisions2(
    df,
    sde_point_our_post,
    sde_trajec_our_post_plot,
    palette='viridis',
    df_time_key='samples',
    x='x1',
    y='x2',
    z='x3',
    is_3d=False,
    save=True,
    path=exp_dir,
    file='sde_trajec_our_post_plot.pdf',
)
[44]:
../_images/notebook_emt_36_0.png