[1]:
import os, random, argparse, pandas as pd, numpy as np, seaborn as sns
from tqdm import tqdm
import torch, torch.nn as nn
# Set random seeds for reproducibility (Python `random`, NumPy and PyTorch).
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous; it is a
# debugging aid that slows training and can be dropped for normal runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
SEED = 10
random.seed(SEED)
np.random.seed(SEED)
torch.random.manual_seed(SEED)
# Load package requirements: DeepRUOT losses, utilities, plotting, training, models.
from DeepRUOT.losses import MMD_loss, OT_loss1, OT_loss2, Density_loss, Local_density_loss
from DeepRUOT.utils import group_extract, sample, to_np, generate_steps, cal_mass_loss, parser, _valid_criterions
from DeepRUOT.plots import plot_comparision, plot_losses
from DeepRUOT.train import train_un1
from DeepRUOT.models import velocityNet, growthNet, scoreNet, dediffusionNet, indediffusionNet, FNet, ODEFunc2
from DeepRUOT.constants import ROOT_DIR, DATA_DIR, NTBK_DIR, IMGS_DIR, RES_DIR
from DeepRUOT.exp import setup_exp
from DeepRUOT.eval import generate_plot_data
# Adjoint-method ODE solver (aliased to `odeint`); the plain solver is kept below for reference.
from torchdiffeq import odeint_adjoint as odeint
#from torchdiffeq import odeint
# For geodesic learning
from scipy.spatial import distance_matrix
from sklearn.gaussian_process.kernels import RBF
from sklearn.manifold import MDS
[2]:
import sys
import torch.optim as optim

# Dimensionality of the state space (the data has gene coordinates x1 and x2).
dim = 2
# Joint model holding the velocity (v_net), growth (g_net), score (s_net) and
# diffusion (d_net) sub-networks.
f_net = FNet(in_out_dim=dim, hidden_dim=128, n_hiddens=3, activation='leakyrelu')

# Simulate command-line arguments so the shared DeepRUOT argument parser can be
# reused from inside the notebook.
sys.argv = [
    'DeepRUOT Training',
    '-d', 'file',
    '-c', 'ot1',
    '-n', 'mouse_hematopoiesis',
]
args = parser.parse_args()
opts = vars(args)
# Display the parsed arguments
print(opts)

# Use CUDA only when it is both available and requested (same rule as `use_cuda`
# later on); otherwise fall back to the CPU instead of failing on the first
# `.to(device)`. On Apple silicon, 'mps' could be used instead.
device = torch.device('cuda' if torch.cuda.is_available() and opts['cuda'] else 'cpu')
device
{'dataset': 'file', 'time_col': None, 'name': 'mouse_hematopoiesis', 'output_dir': '/lustre/home/2301110060/DeepRUOT/results', 'local_epochs': 5, 'epochs': 15, 'local_post_epochs': 5, 'criterion': 'ot1', 'batches': 100, 'cuda': True, 'sample_size': 100, 'sample_with_replacement': False, 'hold_one_out': True, 'hold_out': 'random', 'apply_losses_in_time': True, 'top_k': 5, 'hinge_value': 0.01, 'use_density_loss': True, 'use_local_density': False, 'lambda_density': 1.0, 'lambda_density_local': 1.0, 'lambda_local': 0.2, 'lambda_global': 0.8, 'model_layers': [64], 'use_geo': False, 'geo_layers': [32], 'geo_features': 5, 'n_points': 100, 'n_trajectories': 30, 'n_bins': 100}
[2]:
device(type='cuda')
[3]:
# Make sure the output directory exists (no-op when it already does).
os.makedirs(opts['output_dir'], exist_ok=True)
# Create the experiment directory and its logger.
exp_dir, logger = setup_exp(opts['output_dir'], opts, opts['name'])

# Load dataset. Expected columns: 'samples' (time-point label), 'x1', 'x2'.
logger.info('Loading dataset')
df = pd.read_csv(os.path.join(DATA_DIR, 'mouse_hematopoiesis.csv'))
[4]:
# setup groups
groups = sorted(df.samples.unique())
steps = generate_steps(groups)
logger.info(f'Defining model')
use_geo = opts['use_geo']
model_layers = opts['model_layers']
model_features = len(df.columns) - 1
logger.info(f'Defining optimizer and criterion')
optimizer = torch.optim.Adam(f_net.parameters())
opts['criterion']='ot1'
criterion = _valid_criterions[opts['criterion']]()
logger.info(f'Extracting parameters')
use_cuda = torch.cuda.is_available() and opts['cuda']
sample_size = (opts['sample_size'], )
sample_with_replacement = opts['sample_with_replacement' ]
apply_losses_in_time = opts['apply_losses_in_time']
n_local_epochs = opts['local_epochs']
n_epochs = opts['epochs']
n_post_local_epochs = opts['local_post_epochs']
n_batches = opts['batches']
hold_one_out = opts['hold_one_out']
hold_out = opts['hold_out']
hinge_value = opts['hinge_value']
top_k = opts['top_k']
lambda_density = opts['lambda_density']
lambda_density_local = opts['lambda_density_local']
use_density_loss = opts['use_density_loss']
use_local_density = opts['use_local_density']
lambda_local = opts['lambda_local']
lambda_global = opts['lambda_global']
n_points=opts['n_points']
n_trajectories=opts['n_trajectories']
n_bins=opts['n_bins']
local_losses = {f'{t0}:{t1}':[] for (t0, t1) in steps}
batch_losses = []
globe_losses = []
[5]:
# Move the model's parameters and buffers onto the selected device.
f_net = f_net.to(device=device)
[6]:
# Number of observations at the initial time point (samples == 0).
initial_size = int((df['samples'] == 0).sum())
initial_size
[6]:
1429
[7]:
# Observation count per time point, expressed relative to the first time point;
# passed to `train_un1` as `relative_mass`.
sample_sizes = df.groupby('samples').size()
ref0 = sample_sizes.div(sample_sizes.iloc[0])
relative_mass = torch.tensor(ref0.to_numpy())
relative_mass
[7]:
tensor([1.0000, 2.6459, 4.0504], dtype=torch.float64)
[8]:
# Use the full t=0 population size as the training sample size (overrides the
# value taken from opts['sample_size'] earlier).
sample_size = (int((df['samples'] == 0.0).sum()),)
Pretrain velocity and growth
[9]:
# Stage 1: fit velocity and growth with local losses only, with the mass loss
# weighted by lambda_mass=1. The best model is written to exp_dir/best_model.
if n_local_epochs > 0:
    logger.info(f'Beginning pretraining')
    pretrain_kwargs = dict(
        criterion=criterion,
        use_cuda=use_cuda,
        local_loss=True,
        global_loss=False,
        apply_losses_in_time=apply_losses_in_time,
        hold_one_out=hold_one_out,
        hold_out=hold_out,
        hinge_value=hinge_value,
        lambda_ot=0.1,
        lambda_mass=1,
        lambda_energy=0.001,
        use_pinn=False,
        use_penalty=False,
        use_density_loss=False,
        lambda_density=10,
        top_k=top_k,
        sample_size=sample_size,
        relative_mass=relative_mass,
        initial_size=initial_size,
        sample_with_replacement=sample_with_replacement,
        logger=logger,
        device=device,
        best_model_path=exp_dir + '/best_model',
    )
    for epoch in tqdm(range(1), desc='Pretraining Epoch'):
        l_loss, b_loss, g_loss = train_un1(f_net, df, groups, optimizer, 30, **pretrain_kwargs)
        # Append this run's losses to the running histories.
        for k, v in l_loss.items():
            local_losses[k].extend(v)
        batch_losses.extend(b_loss)
        globe_losses.extend(g_loss)
Pretraining Epoch: 0%| | 0/1 [00:00<?, ?it/s]
begin local loss
/lustre/home/2301110060/software/miniconda3/envs/DeepRUOT/lib/python3.10/site-packages/ot/lp/__init__.py:388: UserWarning: numItermax reached before optimality. Try to increase numItermax.
result_code_string = check_result(result_code)
Otloss
tensor(1.1143, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3382, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(1.0771, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.5962, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 1.0771148204803467. Model saved.
Otloss
tensor(1.0209, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2941, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.8717, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.1971, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.8717283606529236. Model saved.
Otloss
tensor(0.8197, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2450, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.7380, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.4149, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.7379590272903442. Model saved.
Otloss
tensor(0.6309, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2298, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4812, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.0960, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.48116445541381836. Model saved.
Otloss
tensor(0.4669, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2594, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4089, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.1875, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.4089199900627136. Model saved.
Otloss
tensor(0.3641, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3158, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.6296, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.2053, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3334, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3524, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.7308, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.7692, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3318, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3459, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.7268, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.5117, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3376, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3319, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.6004, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.2332, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3511, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3092, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4357, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.8135, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3687, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2679, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3463, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6991, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.34631145000457764. Model saved.
Otloss
tensor(0.3841, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2297, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3083, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6724, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.3082798719406128. Model saved.
Otloss
tensor(0.3912, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2052, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2930, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6644, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.2929557263851166. Model saved.
Otloss
tensor(0.3838, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1885, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2756, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6024, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.27556586265563965. Model saved.
Otloss
tensor(0.3619, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2012, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2679, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5370, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.2679210305213928. Model saved.
Otloss
tensor(0.3315, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2330, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2853, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5522, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3018, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2452, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3082, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6128, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2832, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2389, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3362, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6046, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2715, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2674, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3379, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6357, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2715, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2837, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3249, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6257, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2720, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2748, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2709, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5766, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2798, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2453, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2699, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5618, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2803, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2546, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2433, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6724, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.24325743317604065. Model saved.
Otloss
tensor(0.2837, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2623, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2487, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.7271, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2773, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2628, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2349, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.7190, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.2349427342414856. Model saved.
Otloss
tensor(0.2687, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2507, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2491, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6757, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2614, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2548, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2532, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6341, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2525, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2832, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2267, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6657, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.22673574090003967. Model saved.
Otloss
tensor(0.2531, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2977, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2337, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.7177, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2442, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2413, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2256, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6467, device='cuda:0', grad_fn=<PowBackward0>)
Pretraining Epoch: 100%|██████████| 1/1 [01:49<00:00, 109.63s/it]
New minimum otloss found: 0.22561463713645935. Model saved.
[10]:
# Restore the best checkpoint saved during the first pretraining stage.
best_model_path = os.path.join(exp_dir, 'best_model')
f_net.load_state_dict(torch.load(best_model_path, map_location=torch.device('cpu')))
f_net.to(device)
# Freeze the growth network for the next stage. Also drop any gradient left over
# from stage 1. torch.optim optimizers skip parameters whose .grad is None, but they
# still step parameters that hold a (zero) gradient tensor. Adam's stored momentum
# would then keep moving the "frozen" weights. Stale gradient tensors survive when
# gradients are zeroed with set_to_none=False.
for param in f_net.g_net.parameters():
    param.requires_grad = False
    param.grad = None
[16]:
# Stage 2: same local-loss setup as stage 1 but with lambda_mass=0. The growth
# network's parameters have requires_grad=False at this point (set above).
if n_local_epochs > 0:
    logger.info(f'Beginning pretraining')
    stage_two_kwargs = {
        'criterion': criterion,
        'use_cuda': use_cuda,
        'local_loss': True,
        'global_loss': False,
        'apply_losses_in_time': apply_losses_in_time,
        'hold_one_out': hold_one_out,
        'hold_out': hold_out,
        'hinge_value': hinge_value,
        'lambda_ot': 0.1,
        'lambda_mass': 0,
        'lambda_energy': 0.001,
        'use_pinn': False,
        'use_penalty': False,
        'use_density_loss': False,
        'lambda_density': 10,
        'top_k': top_k,
        'sample_size': sample_size,
        'relative_mass': relative_mass,
        'initial_size': initial_size,
        'sample_with_replacement': sample_with_replacement,
        'logger': logger,
        'device': device,
        'best_model_path': exp_dir + '/best_model',
    }
    for epoch in tqdm(range(1), desc='Pretraining Epoch'):
        l_loss, b_loss, g_loss = train_un1(f_net, df, groups, optimizer, 30, **stage_two_kwargs)
        for k, v in l_loss.items():
            local_losses[k].extend(v)
        batch_losses.extend(b_loss)
        globe_losses.extend(g_loss)
Pretraining Epoch: 0%| | 0/1 [00:00<?, ?it/s]
begin local loss
Otloss
tensor(0.0806, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0523, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1259, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1119, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.12586644291877747. Model saved.
Otloss
tensor(0.0798, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0460, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1580, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1433, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0701, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0532, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0693, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1053, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.06929072737693787. Model saved.
Otloss
tensor(0.0812, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0773, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0773, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1207, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0617, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0421, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0859, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0982, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0779, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0456, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1192, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1052, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0490, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0320, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0489, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1026, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.04894893616437912. Model saved.
Otloss
tensor(0.0751, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0499, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0839, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1155, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0450, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0339, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0294, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1010, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.02942836657166481. Model saved.
Otloss
tensor(0.0618, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0337, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1261, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1318, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0447, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0284, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0605, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1208, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0561, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0415, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0722, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1051, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0529, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0348, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0540, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1136, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0506, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0252, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0985, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1593, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0531, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0258, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0922, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1139, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0438, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0311, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1222, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1248, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0623, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0429, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1063, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0945, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0373, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0246, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0610, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1919, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0709, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0222, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1643, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2844, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0396, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0211, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0607, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1159, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0549, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0385, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1572, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1413, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0709, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0432, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1286, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1084, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0286, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0234, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0420, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1798, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0737, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0223, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1674, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3411, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0516, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0218, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0721, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1961, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0323, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0295, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0847, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1124, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0744, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0490, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1579, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1719, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0324, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0292, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0446, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0986, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0416, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0204, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0826, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2405, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0639, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0221, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1219, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2624, device='cuda:0', grad_fn=<PowBackward0>)
Pretraining Epoch: 100%|██████████| 1/1 [01:40<00:00, 100.71s/it]
[17]:
# Reload the best checkpoint written during the second pretraining stage.
f_net.load_state_dict(torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu')))
f_net.to(device)
[17]:
FNet(
(v_net): velocityNet(
(activation): LeakyReLU(negative_slope=0.01)
(net): ModuleList(
(0): Sequential(
(0): Linear(in_features=3, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
)
(1-2): 2 x Sequential(
(0): Linear(in_features=128, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
)
)
(out): Linear(in_features=128, out_features=2, bias=True)
)
(g_net): growthNet(
(activation): LeakyReLU(negative_slope=0.01)
(net): Sequential(
(0): Linear(in_features=3, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=128, out_features=128, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=128, out_features=128, bias=True)
(5): LeakyReLU(negative_slope=0.01)
(6): Linear(in_features=128, out_features=1, bias=True)
)
)
(s_net): scoreNet(
(activation): LeakyReLU(negative_slope=0.01)
(net): Sequential(
(0): Linear(in_features=3, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=128, out_features=128, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=128, out_features=128, bias=True)
(5): LeakyReLU(negative_slope=0.01)
(6): Linear(in_features=128, out_features=1, bias=True)
)
)
(d_net): indediffusionNet(
(activation): LeakyReLU(negative_slope=0.01)
(net): Sequential(
(0): Linear(in_features=1, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=128, out_features=128, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=128, out_features=128, bias=True)
(5): LeakyReLU(negative_slope=0.01)
(6): Linear(in_features=128, out_features=1, bias=True)
)
)
)
[12]:
import torch
import matplotlib.pyplot as plt
import numpy as np


def plot_g_values(df, f_net, device=device, output_file='plot.pdf'):
    """Scatter every observation, coloured by the growth value predicted by ``f_net``.

    For each time point in ``df['samples']``, the (x1, x2) coordinates are passed
    through ``f_net`` at that time. The second element of the returned 4-tuple is
    used as the per-point growth value. The colour scale is clipped at the 99th
    percentile of all growth values.

    Args:
        df: DataFrame with columns ``samples``, ``x1`` and ``x2``.
        f_net: model called as ``f_net(t, data)`` returning a 4-tuple whose second
            element holds the growth value per point.
        device: torch device on which ``f_net`` is evaluated.
        output_file: path the figure is saved to; ``None`` only displays it.
    """
    time_points = df['samples'].unique()
    data_by_time = {}
    for time in time_points:
        subset = df[df['samples'] == time]
        x = torch.tensor(subset['x1'].values, dtype=torch.float32).to(device)
        y = torch.tensor(subset['x2'].values, dtype=torch.float32).to(device)
        data = torch.stack([x, y], dim=1)
        with torch.no_grad():
            t = torch.tensor([time], dtype=torch.float32).to(device)
            _, g, _, _ = f_net(t, data)
        # Flatten so the colormap receives a 1-D array and yields one RGBA row per point.
        data_by_time[time] = {'data': subset, 'g_values': g.detach().cpu().numpy().reshape(-1)}

    all_g_values = np.concatenate([content['g_values'] for content in data_by_time.values()])
    # Clip the colour scale at the 99th percentile so a few large values do not
    # compress the rest of the range.
    vmax_value = np.percentile(all_g_values, 99)
    norm = plt.Normalize(vmin=all_g_values.min(), vmax=vmax_value, clip=True)

    fig, ax = plt.subplots(figsize=(12, 8))
    for time, content in data_by_time.items():
        subset = content['data']
        colors = plt.cm.rainbow(norm(content['g_values']))
        ax.scatter(subset['x1'], subset['x2'], color=colors, label=f'Time {time}', alpha=0.7, marker='o')
    ax.set_xlabel('Gene $X_1$')
    ax.set_ylabel('Gene $X_2$')

    sm = plt.cm.ScalarMappable(cmap='rainbow', norm=norm)
    sm.set_array(all_g_values)
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label('Normalized predicted growth rate')
    # Tick labels show the normalized value rather than the raw growth value.
    cbar.ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{norm(x):.2f}'))

    # Save before showing: some backends clear the figure once it has been shown.
    if output_file is not None:
        fig.savefig(output_file, bbox_inches='tight')
    plt.show()


plot_g_values(df, f_net, output_file='gene_growth_pre_post.pdf')
[19]:
# generate plot data
generated, trajectories = generate_plot_data(
f_net, df, n_points=500, n_trajectories=50, n_bins=100,
sample_with_replacement=True, use_cuda=use_cuda, samples_key='samples',
logger=logger
)
tensor([0., 1., 2.], device='cuda:0')
<class 'torch.Tensor'>
tensor([0.0000, 0.0202, 0.0404, 0.0606, 0.0808, 0.1010, 0.1212, 0.1414, 0.1616,
0.1818, 0.2020, 0.2222, 0.2424, 0.2626, 0.2828, 0.3030, 0.3232, 0.3434,
0.3636, 0.3838, 0.4040, 0.4242, 0.4444, 0.4646, 0.4848, 0.5051, 0.5253,
0.5455, 0.5657, 0.5859, 0.6061, 0.6263, 0.6465, 0.6667, 0.6869, 0.7071,
0.7273, 0.7475, 0.7677, 0.7879, 0.8081, 0.8283, 0.8485, 0.8687, 0.8889,
0.9091, 0.9293, 0.9495, 0.9697, 0.9899, 1.0101, 1.0303, 1.0505, 1.0707,
1.0909, 1.1111, 1.1313, 1.1515, 1.1717, 1.1919, 1.2121, 1.2323, 1.2525,
1.2727, 1.2929, 1.3131, 1.3333, 1.3535, 1.3737, 1.3939, 1.4141, 1.4343,
1.4545, 1.4747, 1.4949, 1.5152, 1.5354, 1.5556, 1.5758, 1.5960, 1.6162,
1.6364, 1.6566, 1.6768, 1.6970, 1.7172, 1.7374, 1.7576, 1.7778, 1.7980,
1.8182, 1.8384, 1.8586, 1.8788, 1.8990, 1.9192, 1.9394, 1.9596, 1.9798,
2.0000], device='cuda:0')
<class 'torch.Tensor'>
[20]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# NOTE(review): this reseeds NumPy's global RNG mid-notebook, which changes any later
# NumPy-based sampling; nothing in this cell itself draws random numbers.
np.random.seed(42)


def plot_combined_data(df, generated):
    """Overlay observed data and model-generated samples in the (x1, x2) plane.

    Args:
        df: DataFrame with columns ``samples``, ``x1`` and ``x2``.
        generated: array indexed as ``generated[i, :, d]`` -- slice ``i`` (plotted
            as "Time i"), all points, coordinate ``d`` (0 -> x1, 1 -> x2).
    """
    # Sorted so that colour i corresponds to the i-th time point, not to the order
    # in which labels first appear in the DataFrame.
    times = np.sort(df['samples'].unique())
    n_generated = generated.shape[0]
    # Size the palette for whichever is larger (observed time points or generated
    # slices) so neither loop below can index past its end.
    colors = plt.get_cmap('viridis')(np.linspace(0, 1, max(len(times), n_generated)))

    fig, ax = plt.subplots(figsize=(12, 8))
    for i, label in enumerate(times):
        subset = df[df['samples'] == label]
        ax.scatter(subset['x1'], subset['x2'], label=f'Time {label} (df)', alpha=0.1, color=colors[i], marker='X')
    for i in range(n_generated):
        ax.scatter(generated[i, :, 0], generated[i, :, 1], color=colors[i], label=f'Time {i} (generated)', alpha=0.7)

    ax.set_xlim(min(df['x1'].min(), generated[:, :, 0].min()), max(df['x1'].max(), generated[:, :, 0].max()))
    ax.set_ylim(min(df['x2'].min(), generated[:, :, 1].min()), max(df['x2'].max(), generated[:, :, 1].max()))
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.legend()
    ax.set_title('Combined Scatter Plot of df and Generated Data')
    plt.show()


plot_combined_data(df, generated)
[21]:
# Compare observed data with generated samples and trajectories; with save=True the
# figure is presumably written to exp_dir/comparision.png (TODO confirm in DeepRUOT.plots).
plot_comparision(
    df,
    generated,
    trajectories,
    palette='viridis',
    df_time_key='samples',
    save=True,
    path=exp_dir,
    file='comparision.png',
    x='x1',
    y='x2',
    is_3d=False,
)
[21]:
[22]:
# Integrate the learned dynamics (ODEFunc2 around f_net) from the t=0 population over
# [0, 2] with fixed-step Euler, then normalize the final per-particle mass to weights.
data = torch.tensor(df[df['samples'] == 0].values, dtype=torch.float32)
# Column 0 is `samples`; columns 1:3 are the coordinates (x1, x2).
data_t0 = data[:, 1:3].clone().detach().to(device).requires_grad_()
# Initial log-weights: each particle starts with mass 1 / initial_size. The row count
# comes from data_t0 so the weights always match the number of particles.
lnw0 = torch.log(torch.ones(data_t0.shape[0], 1, dtype=torch.float32) / initial_size).to(device).requires_grad_()
initial_state_energy = (data_t0, lnw0)
# Build the time grid on the state's device; torchdiffeq otherwise warns and coerces it.
t = torch.tensor([0.0, 2.0], dtype=torch.float32, device=device).requires_grad_()
x_t, lnw_t = odeint(ODEFunc2(f_net), initial_state_energy, t, options=dict(step_size=0.01), method='euler')
final_mass = torch.exp(lnw_t[-1])
weight = final_mass / final_mass.sum()
/lustre/home/2301110060/software/miniconda3/envs/DeepRUOT/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py:306: UserWarning: t is not on the same device as y0. Coercing to y0.device.
warnings.warn("t is not on the same device as y0. Coercing to y0.device.")
Pretrain score
[10]:
import pandas as pd
import anndata as ad
import scanpy as sc

# Start score pretraining from the best velocity/growth checkpoint.
f_net.load_state_dict(torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu')))
f_net.to(device)
print("DataFrame shape:", df.shape)
print("DataFrame columns:", df.columns)

n = dim
samples = df['samples'].values
column_names = [f'x{i}' for i in range(1, n + 1)]
obsm_data = df[column_names].values
print("obsm_data shape:", obsm_data.shape)

# Use the unique DataFrame row labels as observation names and keep the time label
# in an `obs` column. Indexing by `samples` produced duplicate obs_names.
adata = ad.AnnData(obs=pd.DataFrame({'samples': samples}, index=df.index.astype(str)))
adata.obsm['X_pca'] = obsm_data
adata_loaded = adata
print(adata_loaded)

sc.pl.scatter(adata, basis="pca", color="samples")
n_times = len(adata.obs["samples"].unique())
print(n_times)
# One coordinate array per time point, ordered by sorted time label. This equals
# range(n_times) when labels are 0..n_times-1 and stays correct for other labels.
X = [
    adata.obsm["X_pca"][(adata.obs["samples"] == t).to_numpy()]
    for t in sorted(adata.obs["samples"].unique())
]
DataFrame shape: (10998, 3)
DataFrame columns: Index(['samples', 'x1', 'x2'], dtype='object')
obsm_data shape: (10998, 2)
AnnData object with n_obs × n_vars = 10998 × 0
obsm: 'X_pca'
/lustre/home/2301110060/software/miniconda3/envs/DeepRUOT/lib/python3.10/site-packages/anndata/_core/aligned_df.py:68: ImplicitModificationWarning: Transforming to str index.
warnings.warn("Transforming to str index.", ImplicitModificationWarning)
/lustre/home/2301110060/software/miniconda3/envs/DeepRUOT/lib/python3.10/site-packages/anndata/_core/anndata.py:1756: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
utils.warn_names_duplicates("obs")
3
[11]:
from DeepRUOT.utils import (
    OTPlanSampler,
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    SchrodingerBridgeConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
    generate_state_trajectory,
)
from DeepRUOT.models import scoreNet2

# Batch size for score training: the number of observations at t=0.
batch_size = int((df['samples'] == 0).sum())
sigma = 0.25
time = torch.Tensor(groups)
SF2M = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)
sf2m_score_model = scoreNet2(in_out_dim=dim, hidden_dim=128, activation='leakyrelu').float().to(device)
sf2m_optimizer = torch.optim.Adam(list(sf2m_score_model.parameters()), lr=1e-4)
# State trajectories from the trained f_net, used by get_batch during score training.
trajectory = generate_state_trajectory(X, n_times, batch_size, f_net, time, device)
[22]:
from DeepRUOT.utils import get_batch

# Number of score-matching optimization steps.
N_SCORE_ITERS = 11
# Penalty bookkeeping; the penalty term is currently disabled (lambda_penalty = 0).
max_norm_ut = torch.tensor(0.0)
lambda_penalty = 0
for i in tqdm(range(N_SCORE_ITERS)):
    sf2m_optimizer.zero_grad()
    t, xt, ut, eps = get_batch(SF2M, X, trajectory, batch_size, n_times, return_noise=True)
    t = torch.unsqueeze(t, 1)
    # lambda_t must be a (batch, 1) column so it broadcasts against st (assumed
    # (batch, dim), the gradient w.r.t. xt). t is already (batch, 1) here, so an
    # element-wise compute_lambda returns (batch, 1). The former `lambda_t[:, None]`
    # turned that into (batch, 1, 1), silently broadcasting the loss to
    # (batch, batch, dim). reshape(-1, 1) is right for both (batch,) and (batch, 1).
    # NOTE(review): confirm compute_lambda is element-wise in DeepRUOT.utils.
    lambda_t = SF2M.compute_lambda(t % 1).reshape(-1, 1)
    value_st = sf2m_score_model(t, xt)
    st = sf2m_score_model.compute_gradient(t, xt)
    positive_st = torch.relu(value_st)
    # Denoising score matching: drive lambda_t * st towards -eps.
    score_loss = torch.mean((lambda_t * st + eps) ** 2)
    if i % 10 == 0:
        # Debug output: largest positive value of the raw score-model output.
        print(torch.max(positive_st))
        print(f"{i}: {score_loss.item():0.2f}")
    loss = score_loss
    loss.backward()
    sf2m_optimizer.step()  # use mps may be faster
9%|▉ | 1/11 [00:00<00:04, 2.07it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
0: 0.90
100%|██████████| 11/11 [00:05<00:00, 2.10it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
10: 0.88
[23]:
score_model_path = os.path.join(exp_dir, 'score_model')
torch.save(sf2m_score_model.state_dict(), score_model_path)
[16]:
import numpy as np
import matplotlib.pyplot as plt
import torch
# Visualise the score model at a fixed time on a regular 2-D grid: filled
# contours show the negated model output, arrows show its gradient w.r.t. x.
GRID_RES = 100
x_range = np.linspace(-2, 1.5, GRID_RES)
y_range = np.linspace(-1.5, 2, GRID_RES)
xv, yv = np.meshgrid(x_range, y_range)
grid_points = np.stack([xv, yv], axis=-1).reshape(-1, 2)
grid_points_tensor = torch.tensor(grid_points).float().to(device)
expanded_tensor = grid_points_tensor.requires_grad_(True)
t_value = 2.0
t_tensor = torch.full((grid_points.shape[0], 1), t_value, device=device)
print(t_tensor.shape)
log_density_values = sf2m_score_model(t_tensor, expanded_tensor)
density_values = torch.exp(log_density_values)
# Differentiate w.r.t. the grid points only. Calling .backward() here would also
# accumulate .grad into the score model's parameters as a side effect.
(gradients,) = torch.autograd.grad(
    log_density_values,
    expanded_tensor,
    grad_outputs=torch.ones_like(log_density_values),
)
print(gradients.shape)
gradients_np = gradients.cpu().detach().numpy().reshape(GRID_RES, GRID_RES, 2)
# Thin the arrows so the quiver plot stays readable.
step = 5
xv_quiver = xv[::step, ::step]
yv_quiver = yv[::step, ::step]
gradients_np_quiver = gradients_np[::step, ::step, :]
plt.figure(figsize=(8, 6))
plt.contourf(xv, yv, -log_density_values.cpu().detach().numpy().reshape(GRID_RES, GRID_RES), levels=50, cmap='rainbow')
plt.colorbar(label='negated model output')
plt.quiver(xv_quiver, yv_quiver, gradients_np_quiver[:, :, 0], gradients_np_quiver[:, :, 1], color='white', scale=5)
plt.xlabel('x1')
plt.ylabel('x2')
plt.title(f'Score model at t = {t_value}')
plt.show()
torch.Size([10000, 1])
torch.Size([10000, 2])
Train
[12]:
# Reload the pre-trained checkpoints. They are mapped to CPU first so they load
# on any machine, then both models are moved to the active device.
sf2m_score_model.load_state_dict(
    torch.load(os.path.join(exp_dir, 'score_model'), map_location=torch.device('cpu'))
)
sf2m_score_model.to(device)
f_net.load_state_dict(
    torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu'))
)
f_net.to(device)
# Make every parameter of the growth network trainable for fine-tuning.
f_net.g_net.requires_grad_(True)
[12]:
FNet(
(v_net): velocityNet(
(activation): LeakyReLU(negative_slope=0.01)
(net): ModuleList(
(0): Sequential(
(0): Linear(in_features=3, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
)
(1-2): 2 x Sequential(
(0): Linear(in_features=128, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
)
)
(out): Linear(in_features=128, out_features=2, bias=True)
)
(g_net): growthNet(
(activation): LeakyReLU(negative_slope=0.01)
(net): Sequential(
(0): Linear(in_features=3, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=128, out_features=128, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=128, out_features=128, bias=True)
(5): LeakyReLU(negative_slope=0.01)
(6): Linear(in_features=128, out_features=1, bias=True)
)
)
(s_net): scoreNet(
(activation): LeakyReLU(negative_slope=0.01)
(net): Sequential(
(0): Linear(in_features=3, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=128, out_features=128, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=128, out_features=128, bias=True)
(5): LeakyReLU(negative_slope=0.01)
(6): Linear(in_features=128, out_features=1, bias=True)
)
)
(d_net): indediffusionNet(
(activation): LeakyReLU(negative_slope=0.01)
(net): Sequential(
(0): Linear(in_features=1, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=128, out_features=128, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=128, out_features=128, bias=True)
(5): LeakyReLU(negative_slope=0.01)
(6): Linear(in_features=128, out_features=1, bias=True)
)
)
)
[13]:
import numpy as np
from DeepRUOT.utils import density1
import matplotlib.pyplot as plt
import torch
# Coordinates of all cells at time label 0 (float32, shape (n_cells, 2)); also
# passed to train_all below as `datatime0`.
datatime0 = torch.tensor(
    df.loc[df['samples'] == 0, ['x1', 'x2']].to_numpy(), dtype=torch.float32
)
# NOTE: overrides the global `device` for the remainder of the notebook.
device = 'cpu'
GRID_RES = 100
x_range = np.linspace(-2, 1.5, GRID_RES)
y_range = np.linspace(-1.5, 2, GRID_RES)
xv, yv = np.meshgrid(x_range, y_range)
grid_points = np.stack([xv, yv], axis=-1).reshape(-1, 2)
grid_points_tensor = torch.tensor(grid_points).float().to(device)
with torch.no_grad():
    function_values = density1(grid_points_tensor, datatime0, device).cpu().numpy().reshape(GRID_RES, GRID_RES)
plt.figure(figsize=(8, 6))
plt.contourf(xv, yv, function_values, levels=100, cmap='rainbow')
plt.colorbar(label='Density value')
plt.xlabel('x1')
plt.ylabel('x2')
# Title derived from the actual grid; the old title claimed [0, 2.5] x [0, 2.5].
plt.title(f'Density Plot on [{x_range[0]}, {x_range[-1]}] x [{y_range[0]}, {y_range[-1]}]')
plt.show()
[14]:
from DeepRUOT.train import train_all
# Fine-tune on CPU. The models were moved to the earlier `device` (cuda in the
# configuration cell) when their checkpoints were loaded, so move them to the
# training device before building the optimizer. This keeps them on the same
# device as `datatime0` and the `device` argument passed to train_all.
device = 'cpu'
f_net.to(device)
sf2m_score_model.to(device)
optimizer = torch.optim.SGD(
    list(f_net.parameters()) + list(sf2m_score_model.parameters()), lr=1e-5
)
[15]:
# Joint fine-tuning of f_net and the score model with train_all (local losses only).
if n_local_epochs > 0:
    logger.info('Beginning Training')
    train_all_kwargs = dict(
        criterion=criterion,
        use_cuda=use_cuda,
        local_loss=True,
        global_loss=False,
        apply_losses_in_time=apply_losses_in_time,
        hold_one_out=hold_one_out,
        hold_out=hold_out,
        sf2m_score_model=sf2m_score_model,
        hinge_value=hinge_value,
        datatime0=datatime0,
        device=device,
        lambda_initial=0.1,
        use_pinn=True,
        use_penalty=True,
        use_density_loss=False,
        lambda_density=10,
        top_k=top_k,
        sample_size=sample_size,
        relative_mass=relative_mass,
        initial_size=initial_size,
        sample_with_replacement=sample_with_replacement,
        logger=logger,
        sigmaa=sigma,
        lambda_pinn=1,
    )
    for epoch in tqdm(range(1), desc='Training Epoch'):
        l_loss, b_loss, g_loss = train_all(f_net, df, groups, optimizer, 10, **train_all_kwargs)
        # Accumulate the per-epoch loss histories.
        for loss_name, loss_values in l_loss.items():
            local_losses[loss_name].extend(loss_values)
        batch_losses.extend(b_loss)
        globe_losses.extend(g_loss)
Training Epoch: 0%| | 0/1 [00:00<?, ?it/s]
begin local loss
Otloss
tensor(0.0665, grad_fn=<SumBackward0>)
mass loss
tensor(0.0414, grad_fn=<PowBackward0>)
energy loss
tensor(0.0039, grad_fn=<MeanBackward0>)
pinloss
tensor(0.0001, grad_fn=<MeanBackward0>)
/lustre/home/2301110060/software/miniconda3/envs/DeepRUOT/lib/python3.10/site-packages/ot/lp/__init__.py:388: UserWarning: numItermax reached before optimality. Try to increase numItermax.
result_code_string = check_result(result_code)
Otloss
tensor(0.1336, grad_fn=<SumBackward0>)
mass loss
tensor(0.1319, grad_fn=<PowBackward0>)
energy loss
tensor(0.0112, grad_fn=<MeanBackward0>)
pinloss
tensor(6.9953e-06, grad_fn=<MeanBackward0>)
Otloss
tensor(0.0645, grad_fn=<SumBackward0>)
mass loss
tensor(0.0416, grad_fn=<PowBackward0>)
energy loss
tensor(0.0036, grad_fn=<MeanBackward0>)
pinloss
tensor(0.0001, grad_fn=<MeanBackward0>)
Otloss
tensor(0.1234, grad_fn=<SumBackward0>)
mass loss
tensor(0.1146, grad_fn=<PowBackward0>)
energy loss
tensor(0.0099, grad_fn=<MeanBackward0>)
pinloss
tensor(6.9655e-06, grad_fn=<MeanBackward0>)
Otloss
tensor(0.0636, grad_fn=<SumBackward0>)
mass loss
tensor(0.0420, grad_fn=<PowBackward0>)
energy loss
tensor(0.0035, grad_fn=<MeanBackward0>)
pinloss
tensor(0.0001, grad_fn=<MeanBackward0>)
Otloss
tensor(0.1195, grad_fn=<SumBackward0>)
mass loss
tensor(0.1078, grad_fn=<PowBackward0>)
energy loss
tensor(0.0091, grad_fn=<MeanBackward0>)
pinloss
tensor(6.9466e-06, grad_fn=<MeanBackward0>)
Otloss
tensor(0.0632, grad_fn=<SumBackward0>)
mass loss
tensor(0.0418, grad_fn=<PowBackward0>)
energy loss
tensor(0.0034, grad_fn=<MeanBackward0>)
pinloss
tensor(0.0001, grad_fn=<MeanBackward0>)
Otloss
tensor(0.1178, grad_fn=<SumBackward0>)
mass loss
tensor(0.1057, grad_fn=<PowBackward0>)
energy loss
tensor(0.0086, grad_fn=<MeanBackward0>)
pinloss
tensor(6.9336e-06, grad_fn=<MeanBackward0>)
Otloss
tensor(0.0631, grad_fn=<SumBackward0>)
mass loss
tensor(0.0418, grad_fn=<PowBackward0>)
energy loss
tensor(0.0033, grad_fn=<MeanBackward0>)
pinloss
tensor(0.0001, grad_fn=<MeanBackward0>)
Otloss
tensor(0.1172, grad_fn=<SumBackward0>)
mass loss
tensor(0.1050, grad_fn=<PowBackward0>)
energy loss
tensor(0.0083, grad_fn=<MeanBackward0>)
pinloss
tensor(6.9236e-06, grad_fn=<MeanBackward0>)
Otloss
tensor(0.0630, grad_fn=<SumBackward0>)
mass loss
tensor(0.0415, grad_fn=<PowBackward0>)
energy loss
tensor(0.0032, grad_fn=<MeanBackward0>)
pinloss
tensor(0.0001, grad_fn=<MeanBackward0>)
Otloss
tensor(0.1172, grad_fn=<SumBackward0>)
mass loss
tensor(0.1039, grad_fn=<PowBackward0>)
energy loss
tensor(0.0080, grad_fn=<MeanBackward0>)
pinloss
tensor(6.9159e-06, grad_fn=<MeanBackward0>)
Otloss
tensor(0.0630, grad_fn=<SumBackward0>)
mass loss
tensor(0.0412, grad_fn=<PowBackward0>)
energy loss
tensor(0.0032, grad_fn=<MeanBackward0>)
pinloss
tensor(0.0001, grad_fn=<MeanBackward0>)
Otloss
tensor(0.1174, grad_fn=<SumBackward0>)
mass loss
tensor(0.1037, grad_fn=<PowBackward0>)
energy loss
tensor(0.0077, grad_fn=<MeanBackward0>)
pinloss
tensor(6.9099e-06, grad_fn=<MeanBackward0>)
Otloss
tensor(0.0631, grad_fn=<SumBackward0>)
mass loss
tensor(0.0407, grad_fn=<PowBackward0>)
energy loss
tensor(0.0031, grad_fn=<MeanBackward0>)
pinloss
tensor(0.0001, grad_fn=<MeanBackward0>)
Otloss
tensor(0.1178, grad_fn=<SumBackward0>)
mass loss
tensor(0.1036, grad_fn=<PowBackward0>)
energy loss
tensor(0.0076, grad_fn=<MeanBackward0>)
pinloss
tensor(6.9052e-06, grad_fn=<MeanBackward0>)
Otloss
tensor(0.0631, grad_fn=<SumBackward0>)
mass loss
tensor(0.0408, grad_fn=<PowBackward0>)
energy loss
tensor(0.0031, grad_fn=<MeanBackward0>)
pinloss
tensor(0.0001, grad_fn=<MeanBackward0>)
Otloss
tensor(0.1184, grad_fn=<SumBackward0>)
mass loss
tensor(0.1034, grad_fn=<PowBackward0>)
energy loss
tensor(0.0074, grad_fn=<MeanBackward0>)
pinloss
tensor(6.9014e-06, grad_fn=<MeanBackward0>)
Otloss
tensor(0.0632, grad_fn=<SumBackward0>)
mass loss
tensor(0.0408, grad_fn=<PowBackward0>)
energy loss
tensor(0.0030, grad_fn=<MeanBackward0>)
pinloss
tensor(0.0001, grad_fn=<MeanBackward0>)
Otloss
tensor(0.1188, grad_fn=<SumBackward0>)
mass loss
tensor(0.1036, grad_fn=<PowBackward0>)
energy loss
tensor(0.0073, grad_fn=<MeanBackward0>)
pinloss
tensor(6.8985e-06, grad_fn=<MeanBackward0>)
Training Epoch: 100%|██████████| 1/1 [00:39<00:00, 39.25s/it]
[16]:
# Persist the fine-tuned weights into the experiment directory.
for module, filename in ((sf2m_score_model, 'score_model_result'), (f_net, 'model_result')):
    torch.save(module.state_dict(), os.path.join(exp_dir, filename))
[17]:
# Initial condition for the SDE simulations: all rows observed at time label 0.
mask_t0 = df['samples'] == 0
data = torch.tensor(df[mask_t0].values, dtype=torch.float32).requires_grad_()
# Column 0 is 'samples'; columns 1..dim are the coordinates. Gradient tracking is
# kept on purpose. NOTE(review): presumably required by compute_gradient inside the SDE drift — confirm.
data_t0 = data[:, 1:dim + 1].to(device).requires_grad_()
print(data_t0.shape)
x0 = data_t0.to(device)
torch.Size([1429, 2])
[27]:
import torchsde
class SDE(torch.nn.Module):
    """Itô SDE with diagonal noise for torchsde.

    dX = (ode_drift(t, X) + score.compute_gradient(t, X)) dt + sigma dW

    Args:
        ode_drift: callable ``(t, y) -> tensor`` giving the deterministic drift.
        score: object exposing ``compute_gradient(t, y)``; its result is added to the drift.
        input_size: stored for reference only; not used by ``f`` or ``g``.
        sigma: constant diffusion coefficient applied to every coordinate.
    """

    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, ode_drift, score, input_size=(3, 32, 32), sigma=1.0):
        super().__init__()
        self.drift = ode_drift
        self.score = score
        self.input_size = input_size
        self.sigma = sigma

    def f(self, t, y):
        """Drift term: ODE drift plus the score gradient at the current state."""
        drift = self.drift(t, y)
        # Broadcast the scalar time to one entry per row for the score model.
        t_batch = t.expand(y.shape[0], 1)
        return drift + self.score.compute_gradient(t_batch, y)

    def g(self, t, y):
        """Diagonal diffusion term, constant ``self.sigma`` per coordinate.

        Previously this read the notebook-global ``sigma`` and silently ignored
        the ``sigma`` passed to the constructor.
        """
        return torch.ones_like(y) * self.sigma
# Simulate the SDE from the t=0 cells on a fine, evenly spaced grid of 400 times.
sde = SDE(f_net.v_net, sf2m_score_model, input_size=(dim,), sigma=sigma)
sde_ts = torch.linspace(0, n_times - 1, 400, device=device)
sde_traj = torchsde.sdeint(sde, x0.to(device), ts=sde_ts, dt=0.01).cpu()
[28]:
# Randomly keep a handful of simulated trajectories (axis 1 of sde_traj) and save them.
sample_number = 10
n_simulated = sde_traj.size(1)
sample_indices = random.sample(range(n_simulated), sample_number)
sampled_sde_trajec = np.array(sde_traj[:, sample_indices, :].tolist(), dtype=object)
np.save(exp_dir + '/sde_trajec_our_post_plot.npy', sampled_sde_trajec)
[21]:
# Same SDE, but only record the state at the observed time points.
ts_points = time
sde_point = torchsde.sdeint(sde, x0.to(device), ts=ts_points, dt=0.01).cpu()
[22]:
# Save the point-wise simulation as a NumPy object array (reloaded with allow_pickle=True).
sde_point_np = sde_point.detach().numpy()
sde_point_list = sde_point_np.tolist()
sde_point_array = np.array(sde_point_list, dtype=object)
sde_point_path = exp_dir + '/sde_point_our_post.npy'
np.save(sde_point_path, sde_point_array)
[23]:
sde_point_file = exp_dir + '/sde_point_our_post.npy'
sde_point_our_post = np.load(sde_point_file, allow_pickle=True)
[24]:
# Keep simulated cells 100..299 (axis 1) for plotting. Slice freshly loaded data
# so that re-running this cell does not shrink the array a second time.
PLOT_CELL_SLICE = slice(100, 300)
sde_point_our_post = np.load(exp_dir + '/sde_point_our_post.npy', allow_pickle=True)[:, PLOT_CELL_SLICE, :]
sde_point_our_post.shape
[24]:
(3, 200, 2)
[25]:
from DeepRUOT.plots import new_plot_comparisions2
[29]:
sde_trajec_our_post_plot = np.load(exp_dir + '/sde_trajec_our_post_plot.npy', allow_pickle=True)
plot_kwargs = dict(
    palette='viridis',
    df_time_key='samples',
    save=True,
    path=exp_dir,
    file='sde_trajec_our_post_plot.pdf',
    x='x1',
    y='x2',
    z='x3',
    is_3d=False,
)
new_plot_comparisions2(df, sde_point_our_post, sde_trajec_our_post_plot, **plot_kwargs)
[29]: