[ ]:
import os, random, argparse, pandas as pd, numpy as np, seaborn as sns
from tqdm import tqdm
import torch, torch.nn as nn
# Make CUDA kernel launches synchronous so errors are reported at the call that
# caused them (a debugging aid; it slows GPU execution).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Seed Python's, NumPy's and PyTorch's global RNGs with one shared value so the
# notebook is reproducible.
SEED = 10
for _seed_rng in (random.seed, np.random.seed, torch.random.manual_seed):
    _seed_rng(SEED)
# load package requirements
from DeepRUOT.losses import MMD_loss, OT_loss1, OT_loss2, Density_loss, Local_density_loss
from DeepRUOT.utils import group_extract, sample, to_np, generate_steps, cal_mass_loss, parser, _valid_criterions
from DeepRUOT.plots import plot_comparision, plot_losses
from DeepRUOT.train import train_un1
from DeepRUOT.models import velocityNet, growthNet, scoreNet, dediffusionNet, indediffusionNet, FNet, ODEFunc2
from DeepRUOT.constants import ROOT_DIR, DATA_DIR, NTBK_DIR, IMGS_DIR, RES_DIR
from DeepRUOT.exp import setup_exp
from DeepRUOT.eval import generate_plot_data
from torchdiffeq import odeint_adjoint as odeint
#from torchdiffeq import odeint
# for geodesic learning
from scipy.spatial import distance_matrix
from sklearn.gaussian_process.kernels import RBF
from sklearn.manifold import MDS
[ ]:
import torch.optim as optim
import sys

# Model for the 2-D simulated data. FNet exposes a growth sub-network as
# `f_net.g_net` (it is frozen in a later cell).
dim = 2
f_net = FNet(in_out_dim=dim, hidden_dim=128, n_hiddens=3, activation='leakyrelu')

# Simulated command line. The list is handed straight to parse_args instead of
# overwriting the kernel's own sys.argv (parse_args expects the arguments
# without the program name).
cli_argv = [
    '-d', 'file',
    '-c', 'ot1',
    '-n', 'simulation_gene',
]
args = parser.parse_args(cli_argv)
opts = vars(args)
# Display the parsed arguments
print(opts)

# Use the GPU only when it is both requested (opts['cuda']) and available;
# otherwise fall back to the CPU instead of failing on GPU-less machines.
device = torch.device('cuda' if opts['cuda'] and torch.cuda.is_available() else 'cpu')
device
{'dataset': 'file', 'time_col': None, 'name': 'simulation_gene', 'output_dir': '/lustre/home/2301110060/DeepRUOT/results', 'local_epochs': 5, 'epochs': 15, 'local_post_epochs': 5, 'criterion': 'ot1', 'batches': 100, 'cuda': True, 'sample_size': 100, 'sample_with_replacement': False, 'hold_one_out': True, 'hold_out': 'random', 'apply_losses_in_time': True, 'top_k': 5, 'hinge_value': 0.01, 'use_density_loss': True, 'use_local_density': False, 'lambda_density': 1.0, 'lambda_density_local': 1.0, 'lambda_local': 0.2, 'lambda_global': 0.8, 'model_layers': [64], 'use_geo': False, 'geo_layers': [32], 'geo_features': 5, 'n_points': 100, 'n_trajectories': 30, 'n_bins': 100}
device(type='cuda')
[ ]:
# Create the output directory and experiment logger, then load the simulated
# gene-expression table (column `samples` holds the time-point label).
# exist_ok=True replaces the racy isdir-then-create check; it still raises if
# the path exists as a regular file.
os.makedirs(opts['output_dir'], exist_ok=True)
exp_dir, logger = setup_exp(opts['output_dir'], opts, opts['name'])

# Log *before* loading so the message reflects what is about to happen.
logger.info('Loading dataset')
df = pd.read_csv(os.path.join(DATA_DIR, 'simulation_gene_data.csv'))
[ ]:
# Time points (sorted unique values of `samples`) and the consecutive
# (t0, t1) steps built from them by generate_steps.
groups = sorted(df.samples.unique())
steps = generate_steps(groups)

logger.info('Defining model')
use_geo = opts['use_geo']
model_layers = opts['model_layers']
# presumably the feature count, i.e. all columns except `samples` — TODO confirm
model_features = len(df.columns) - 1

logger.info('Defining optimizer and criterion')
# Move the model to its device *before* constructing the optimizer, as the
# torch.optim documentation recommends, so the optimizer is built over the
# device-resident parameters.
f_net = f_net.to(device)
optimizer = torch.optim.Adam(f_net.parameters())
# Use the criterion selected on the (simulated) command line instead of
# silently forcing 'ot1' here, which would make the CLI option a no-op.
criterion = _valid_criterions[opts['criterion']]()

logger.info('Extracting parameters')
use_cuda = torch.cuda.is_available() and opts['cuda']
sample_size = (opts['sample_size'], )
sample_with_replacement = opts['sample_with_replacement']
apply_losses_in_time = opts['apply_losses_in_time']
n_local_epochs = opts['local_epochs']
n_epochs = opts['epochs']
n_post_local_epochs = opts['local_post_epochs']
n_batches = opts['batches']
hold_one_out = opts['hold_one_out']
hold_out = opts['hold_out']
hinge_value = opts['hinge_value']
top_k = opts['top_k']
lambda_density = opts['lambda_density']
lambda_density_local = opts['lambda_density_local']
use_density_loss = opts['use_density_loss']
use_local_density = opts['use_local_density']
lambda_local = opts['lambda_local']
lambda_global = opts['lambda_global']
n_points = opts['n_points']
n_trajectories = opts['n_trajectories']
n_bins = opts['n_bins']

# Loss histories accumulated across training cells; local losses are keyed
# by "t0:t1" for each consecutive step.
local_losses = {f'{t0}:{t1}': [] for (t0, t1) in steps}
batch_losses = []
globe_losses = []
[ ]:
# nn.Module.to() moves parameters/buffers in place and returns the module itself.
f_net = f_net.to(device=device)
[ ]:
# Number of cells at the earliest time point. This is the same reference group
# (first group of groupby('samples'), i.e. the smallest label) used to normalize
# relative_mass. It no longer hard-codes the label 0 or relies on an `x1` column.
initial_size = int(df.groupby('samples').size().iloc[0])
initial_size
400
[ ]:
# Cell count per time point (groupby sorts the labels), expressed relative to
# the first time point; kept as a float64 tensor for the training routine.
sample_sizes = df.groupby('samples').size()
ref0 = sample_sizes.div(sample_sizes.iloc[0])
relative_mass = torch.tensor(ref0.to_numpy())
relative_mass
tensor([1.0000, 1.1050, 1.3250, 1.7250, 2.4225], dtype=torch.float64)
[ ]:
# Override the configured sample size: draw as many points as there are cells at
# the earliest time point (first group of groupby('samples')). The label 0.0 is
# no longer hard-coded, and rows are counted without materializing `.values`.
sample_size = (int(df.groupby('samples').size().iloc[0]),)
Pretrain the velocity and growth models
[9]:
# Stage 1: train f_net on local losses only (local_loss=True, global_loss=False),
# with the mass term enabled (lambda_mass=1). This stage's hyper-parameters are
# passed inline rather than read from `opts`.
# NOTE(review): only `n_local_epochs > 0` is consulted. The loop always runs once,
# and the literal 50 (5th positional argument of train_un1; the log shows ~50
# rounds) sets the training length — confirm this is intended.
if n_local_epochs > 0:
    logger.info('Beginning pretraining')
    for epoch in tqdm(range(1), desc='Pretraining Epoch'):
        l_loss, b_loss, g_loss = train_un1(
            f_net, df, groups, optimizer, 50,
            # which losses are applied
            local_loss=True,
            global_loss=False,
            apply_losses_in_time=apply_losses_in_time,
            criterion=criterion,
            use_pinn=False,
            use_penalty=False,
            use_density_loss=False,
            # loss weights / thresholds
            lambda_ot=0.1,
            lambda_mass=1,
            lambda_energy=0.001,
            lambda_density=10,
            hinge_value=hinge_value,
            top_k=top_k,
            # data handling
            hold_one_out=hold_one_out,
            hold_out=hold_out,
            sample_size=sample_size,
            sample_with_replacement=sample_with_replacement,
            relative_mass=relative_mass,
            initial_size=initial_size,
            # runtime / bookkeeping
            use_cuda=use_cuda,
            device=device,
            logger=logger,
            best_model_path=exp_dir + '/best_model',
        )
        # Append this run's loss histories to the notebook-wide accumulators.
        for k, v in l_loss.items():
            local_losses[k].extend(v)
        batch_losses.extend(b_loss)
        globe_losses.extend(g_loss)
Pretraining Epoch: 0%| | 0/1 [00:00<?, ?it/s]
begin local loss
Otloss
tensor(0.6566, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3542, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(2.1351, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5532, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(2.9711, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.4927, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(2.5609, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(3.6765, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 2.560936689376831. Model saved.
Otloss
tensor(0.3903, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3421, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(1.0585, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5650, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(1.1458, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6276, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.8818, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.2372, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.8817991614341736. Model saved.
Otloss
tensor(0.2439, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2634, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.5815, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4181, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.8346, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.4395, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(2.0859, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(3.9272, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1976, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1006, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4601, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3077, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.7199, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4148, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(1.4621, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(2.3721, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2114, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0684, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4647, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4900, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.5914, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4921, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.6715, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.0777, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.671518087387085. Model saved.
Otloss
tensor(0.2013, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1786, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4607, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.8148, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.5760, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.1450, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4941, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.0672, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.49408048391342163. Model saved.
Otloss
tensor(0.1685, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2870, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4080, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5511, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.5499, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.0775, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.5608, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(3.3952, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1326, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2709, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3172, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5264, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4591, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.7240, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.5821, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(3.5826, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1108, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2473, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2525, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3309, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4095, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2930, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.5333, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(3.5139, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0992, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2477, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2276, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2814, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3976, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4621, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4696, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.9102, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.4696376323699951. Model saved.
Otloss
tensor(0.0792, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2140, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1972, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2858, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3659, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3108, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4210, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5532, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.42100802063941956. Model saved.
Otloss
tensor(0.0578, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1842, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1628, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3069, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3167, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3530, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4030, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.3493, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.40304407477378845. Model saved.
Otloss
tensor(0.0473, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1672, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1458, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3135, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2877, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3677, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3631, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6896, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.3631232678890228. Model saved.
Otloss
tensor(0.0482, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1007, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1479, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3146, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2978, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3473, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3601, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6455, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.3600827753543854. Model saved.
Otloss
tensor(0.0535, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0638, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1508, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2915, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2859, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3108, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3320, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.9096, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.33204078674316406. Model saved.
Otloss
tensor(0.0455, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0728, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1207, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2933, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2302, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2962, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3241, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.1246, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.32408589124679565. Model saved.
Otloss
tensor(0.0326, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0672, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1015, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3141, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2014, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3314, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2762, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.7233, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.2762255370616913. Model saved.
Otloss
tensor(0.0287, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0240, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1038, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3427, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2015, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4695, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2317, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4760, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.23165540397167206. Model saved.
Otloss
tensor(0.0353, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0122, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1082, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2032, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1896, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4542, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2013, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5592, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.20127204060554504. Model saved.
Otloss
tensor(0.0278, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0150, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0577, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1982, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0901, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3278, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1471, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6023, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.14711709320545197. Model saved.
Otloss
tensor(0.0187, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0263, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0304, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1660, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0461, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3069, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0924, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6480, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.09239967912435532. Model saved.
Otloss
tensor(0.0228, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0131, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0426, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0743, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0575, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1314, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0670, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2547, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.066966712474823. Model saved.
Otloss
tensor(0.0341, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0122, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0756, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0938, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0946, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2522, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0708, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2779, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0190, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0125, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0235, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0592, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0198, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2074, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0965, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4656, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0114, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0243, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0172, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0847, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0404, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2907, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0854, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5300, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0225, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0119, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0527, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0513, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0928, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1386, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1167, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3625, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0376, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0104, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1004, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0669, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1735, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2150, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1889, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4295, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0272, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0144, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0666, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1251, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1044, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2969, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1253, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4006, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0172, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0390, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0537, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2486, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1084, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3151, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1666, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.9502, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0119, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0258, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0478, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2008, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1014, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2950, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1216, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4157, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0189, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0131, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0745, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0863, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1492, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3948, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1558, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.7334, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0192, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0117, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0554, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0807, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0808, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2845, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0733, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4458, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0159, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0204, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0348, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1331, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0498, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3097, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0936, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.0525, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0156, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0136, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0310, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0502, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0281, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1355, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0184, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3380, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.018364552408456802. Model saved.
Otloss
tensor(0.0248, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0105, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0689, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0445, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1092, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2069, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1021, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6795, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0222, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0101, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0529, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0301, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0528, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0774, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0140, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2663, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.013958828523755074. Model saved.
Otloss
tensor(0.0089, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0143, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0157, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0268, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0832, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1482, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1904, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5556, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0093, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0170, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0190, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0438, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0106, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1262, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0257, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3005, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0272, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0087, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0870, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0460, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1673, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1615, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2162, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6354, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0304, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0087, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1076, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0366, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2114, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1064, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.2587, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3294, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0164, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0243, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0663, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1696, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1373, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3170, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1880, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6789, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0147, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0560, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0519, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2739, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1109, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3461, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1682, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.3499, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0072, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0397, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0492, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2590, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1059, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3172, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1311, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5214, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0190, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0181, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0849, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1902, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1654, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6467, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1752, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.3431, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0168, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0109, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0540, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0820, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0899, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3280, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0928, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5471, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0072, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0171, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0291, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0828, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0585, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3695, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1014, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.5029, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0125, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0109, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0151, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0443, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0238, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1898, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0364, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5136, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0078, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0113, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0377, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0462, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0826, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2673, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0955, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.9565, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0130, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0105, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0489, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0354, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0853, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1892, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0758, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6127, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0099, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0135, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0121, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0337, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0237, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2129, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0716, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3514, device='cuda:0', grad_fn=<PowBackward0>)
Pretraining Epoch: 100%|██████████| 1/1 [01:53<00:00, 113.27s/it]
[10]:
# Reload the best checkpoint written during the previous pretraining pass and freeze
# the growth network: its parameters get requires_grad=False, so the next pretraining
# pass produces no gradients for g_net.
f_net.load_state_dict(torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu')))
f_net.to(device)
f_net.g_net.requires_grad_(False);
[18]:
# Pretraining pass with train_un1: only the local loss is enabled (global loss, PINN,
# penalty and density terms are switched off). train_un1 writes its best checkpoint
# to exp_dir/best_model; the per-step losses are appended to the running histories.
if n_local_epochs > 0:
    logger.info('Beginning pretraining')
    for _epoch in tqdm(range(1), desc='Pretraining Epoch'):
        l_loss, b_loss, g_loss = train_un1(
            f_net, df, groups, optimizer, 30,
            # which loss terms are active
            criterion=criterion,
            local_loss=True,
            global_loss=False,
            apply_losses_in_time=apply_losses_in_time,
            use_pinn=False,
            use_penalty=False,
            use_density_loss=False,
            # loss weights
            lambda_ot=0.1,
            lambda_mass=0,
            lambda_energy=0.001,
            lambda_density=10,
            hinge_value=hinge_value,
            # hold-out and sampling configuration
            hold_one_out=hold_one_out,
            hold_out=hold_out,
            top_k=top_k,
            sample_size=sample_size,
            sample_with_replacement=sample_with_replacement,
            relative_mass=relative_mass,
            initial_size=initial_size,
            # runtime / bookkeeping
            use_cuda=use_cuda,
            device=device,
            logger=logger,
            best_model_path=exp_dir + '/best_model',
        )
        for key, values in l_loss.items():
            local_losses[key].extend(values)
        batch_losses.extend(b_loss)
        globe_losses.extend(g_loss)
Pretraining Epoch: 0%| | 0/1 [00:00<?, ?it/s]
begin local loss
Otloss
tensor(0.0116, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0149, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0160, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0319, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0311, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1446, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0669, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5307, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.06685015559196472. Model saved.
Otloss
tensor(0.0185, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0114, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0609, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0896, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0844, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6742, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0713, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.5237, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0125, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0106, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0266, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0535, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0122, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0483, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0638, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1351, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.0637623593211174. Model saved.
Otloss
tensor(0.0091, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0156, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0207, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0428, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0791, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2201, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1708, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.8473, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0089, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0146, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0117, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0245, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0262, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0564, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0665, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1391, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0134, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0092, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0397, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0440, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0461, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3878, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0570, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.0994, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.057039402425289154. Model saved.
Otloss
tensor(0.0123, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0087, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0296, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0316, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0180, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0806, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0464, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2004, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.04644349217414856. Model saved.
Otloss
tensor(0.0083, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0134, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0108, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0179, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0486, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0816, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1127, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2577, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0089, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0154, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0168, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0231, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0540, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0509, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1073, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1471, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0092, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0090, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0197, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0343, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0238, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1040, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0610, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5178, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0139, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0094, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0377, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0394, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0326, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2071, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0507, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.7066, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0097, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0098, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0163, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0217, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0213, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1025, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0645, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3770, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0090, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0180, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0176, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0247, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0607, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1042, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1212, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3233, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0077, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0123, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0110, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0283, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0400, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0488, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0921, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2194, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0126, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0103, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0323, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0445, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0243, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2267, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0575, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.8180, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0140, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0091, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0353, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0282, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0291, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1469, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0508, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6662, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0097, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0116, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0131, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0284, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0342, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2017, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0847, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3761, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0089, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0152, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0256, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0445, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0767, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1165, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1405, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4303, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0091, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0098, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0170, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0499, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0286, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1309, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0750, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4280, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0162, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0098, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0475, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0598, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0559, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4295, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0579, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(2.0272, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0143, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0100, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0330, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0331, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0275, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2246, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0524, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6675, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0099, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0200, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0170, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0255, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0596, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1291, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1230, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4053, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0092, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0138, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0308, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0565, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0872, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1778, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1578, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.8308, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0129, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0115, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0315, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0684, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0245, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2922, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0621, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.1553, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0202, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0102, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0642, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0960, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0907, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.9551, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0793, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(3.1262, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0160, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0152, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0321, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0505, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0282, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2859, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0582, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6879, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0116, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0242, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0315, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0531, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0853, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2232, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1622, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.1060, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0102, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0161, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0344, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0916, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0922, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3067, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1697, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.4423, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0163, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0175, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0455, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0901, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0369, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6302, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0575, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.4646, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0231, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0134, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0758, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1245, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1132, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(1.2625, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1012, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(3.4014, device='cuda:0', grad_fn=<PowBackward0>)
Pretraining Epoch: 100%|██████████| 1/1 [01:02<00:00, 62.06s/it]
[19]:
# Restore the best checkpoint that train_un1 saved to exp_dir/best_model during the pretraining pass.
f_net.load_state_dict(torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu')))
[19]:
<All keys matched successfully>
[13]:
import torch
import matplotlib.pyplot as plt
import numpy as np


def plot_g_values(df, f_net, device=device, output_file='plot.pdf'):
    """Scatter every cell in ``df`` coloured by the growth rate predicted by ``f_net``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a ``'samples'`` time column and ``'x1'`` / ``'x2'`` coordinate columns.
    f_net : torch.nn.Module
        Called as ``f_net(t, data)`` and expected to return a 4-tuple; the second
        element is used as the growth rate g.
    device : torch.device or str
        Device the inputs are created on (should match ``f_net``). The default is
        the notebook-level ``device`` at the time this cell is executed.
    output_file : str or None
        If truthy, the figure is also written to this path (before being shown).
        Pass ``None`` to only display it.
    """
    data_by_time = {}
    for time in df['samples'].unique():
        subset = df[df['samples'] == time]
        data = torch.tensor(subset[['x1', 'x2']].to_numpy(), dtype=torch.float32, device=device)
        t = torch.tensor([time], dtype=torch.float32, device=device)
        with torch.no_grad():
            _, g, _, _ = f_net(t, data)
        # Flatten so each cell contributes exactly one colour value.
        data_by_time[time] = {'data': subset, 'g_values': g.detach().cpu().numpy().reshape(-1)}

    all_g_values = np.concatenate([content['g_values'] for content in data_by_time.values()])
    # Cap the colour scale at the 99th percentile so a few extreme rates do not wash out the rest.
    vmax_value = np.percentile(all_g_values, 99)
    norm = plt.Normalize(vmin=all_g_values.min(), vmax=vmax_value, clip=True)

    fig, ax = plt.subplots(figsize=(12, 8))
    for time, content in data_by_time.items():
        subset = content['data']
        colors = plt.cm.rainbow(norm(content['g_values']))
        ax.scatter(subset['x1'], subset['x2'], color=colors, label=f'Time {time}', alpha=0.7, marker='o')
    ax.set_xlabel('Gene $X_1$')
    ax.set_ylabel('Gene $X_2$')

    sm = plt.cm.ScalarMappable(cmap='rainbow', norm=norm)
    sm.set_array(all_g_values)
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label('Normalized predicted growth rate')
    # Tick labels show the normalized [0, 1] value rather than the raw rate.
    cbar.ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{norm(x):.2f}'))

    # Previously output_file was accepted but ignored; save before show() so the file is not blank.
    if output_file:
        fig.savefig(output_file, bbox_inches='tight')
    plt.show()


plot_g_values(df, f_net, output_file='gene_growth_pre_post.pdf')
[20]:
# generate plot data
# Put the model on the device implied by `use_cuda`, the same flag handed to
# generate_plot_data. Presumably that is where it places its inputs; verify in DeepRUOT.eval.
# A hard-coded 'cuda' would fail on machines without a GPU.
f_net.to('cuda' if use_cuda else 'cpu')
generated, trajectories = generate_plot_data(
    f_net, df, n_points=400, n_trajectories=50, n_bins=100,
    sample_with_replacement=True, use_cuda=use_cuda, samples_key='samples',
    logger=logger,
)
tensor([0., 1., 2., 3., 4.], device='cuda:0')
<class 'torch.Tensor'>
tensor([0.0000, 0.0404, 0.0808, 0.1212, 0.1616, 0.2020, 0.2424, 0.2828, 0.3232,
0.3636, 0.4040, 0.4444, 0.4848, 0.5253, 0.5657, 0.6061, 0.6465, 0.6869,
0.7273, 0.7677, 0.8081, 0.8485, 0.8889, 0.9293, 0.9697, 1.0101, 1.0505,
1.0909, 1.1313, 1.1717, 1.2121, 1.2525, 1.2929, 1.3333, 1.3737, 1.4141,
1.4545, 1.4949, 1.5354, 1.5758, 1.6162, 1.6566, 1.6970, 1.7374, 1.7778,
1.8182, 1.8586, 1.8990, 1.9394, 1.9798, 2.0202, 2.0606, 2.1010, 2.1414,
2.1818, 2.2222, 2.2626, 2.3030, 2.3434, 2.3838, 2.4242, 2.4646, 2.5051,
2.5455, 2.5859, 2.6263, 2.6667, 2.7071, 2.7475, 2.7879, 2.8283, 2.8687,
2.9091, 2.9495, 2.9899, 3.0303, 3.0707, 3.1111, 3.1515, 3.1919, 3.2323,
3.2727, 3.3131, 3.3535, 3.3939, 3.4343, 3.4747, 3.5152, 3.5556, 3.5960,
3.6364, 3.6768, 3.7172, 3.7576, 3.7980, 3.8384, 3.8788, 3.9192, 3.9596,
4.0000], device='cuda:0')
<class 'torch.Tensor'>
[21]:
# Overlay observed cells, generated samples and trajectories in the (x1, x2) plane
# and save the figure to exp_dir/comparision.png.
plot_comparision(
    df,
    generated,
    trajectories,
    x='x1',
    y='x2',
    is_3d=False,
    palette='viridis',
    df_time_key='samples',
    save=True,
    path=exp_dir,
    file='comparision.png',
)
[21]:
[22]:
f_net.to(device)
# Full time-0 rows of df (kept for later cells; not used below).
data = torch.tensor(df[df['samples'] == 0].values, dtype=torch.float32)
# Initial positions: time-0 coordinates selected by column name rather than by position.
data_t0 = torch.tensor(
    df.loc[df['samples'] == 0, ['x1', 'x2']].to_numpy(), dtype=torch.float32, device=device
).requires_grad_()
# Equal initial mass 1/initial_size per cell, tracked as a log-weight. Use one row per
# integrated cell so lnw0 always matches data_t0.
n_t0 = data_t0.shape[0]
lnw0 = torch.log(torch.ones(n_t0, 1, dtype=torch.float32) / (initial_size)).to(device).requires_grad_()
initial_state_energy = (data_t0, lnw0)
t = torch.tensor([0.0, 2.0], dtype=torch.float32).requires_grad_()
# Integrate (x, ln w) from t=0 to t=2 with fixed-step Euler (dt=0.01).
x_t, lnw_t = odeint(ODEFunc2(f_net), initial_state_energy, t, options=dict(step_size=0.01), method='euler')
# Convert the final log-weights to masses and normalize them into a distribution over cells.
final_mass = torch.exp(lnw_t[-1])
weight = final_mass / final_mass.sum()
Pretrain score
[ ]:
import pandas as pd
import anndata as ad
import scanpy as sc

# Reload the best dynamics checkpoint before building the score-model training data.
f_net.load_state_dict(torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu')))
f_net.to(device)
print("DataFrame shape:", df.shape)
print("DataFrame columns:", df.columns)

n = dim
samples = df['samples'].values
column_names = [f'x{i}' for i in range(1, n + 1)]
obsm_data = df[column_names].values
print("obsm_data shape:", obsm_data.shape)

# One unique string obs_name per row. This avoids AnnData's "Transforming to str index"
# and "Observation names are not unique" warnings that arise when the time labels are
# used as the index. The time labels live in the 'samples' obs column instead.
obs = pd.DataFrame({'samples': samples}, index=[str(i) for i in range(len(samples))])
adata = ad.AnnData(obs=obs)
adata.obsm['X_pca'] = obsm_data
adata_loaded = adata
print(adata_loaded)
sc.pl.scatter(adata, basis="pca", color="samples")

# One coordinate array per observed time point, ordered by time label. This does not
# assume the labels are exactly 0..n_times-1.
time_labels = sorted(adata.obs["samples"].unique())
n_times = len(time_labels)
print(n_times)
X = [
    adata.obsm["X_pca"][(adata.obs["samples"] == label).to_numpy()]
    for label in time_labels
]
DataFrame shape: (3031, 3)
DataFrame columns: Index(['samples', 'x1', 'x2'], dtype='object')
obsm_data shape: (3031, 2)
AnnData object with n_obs × n_vars = 3031 × 0
obsm: 'X_pca'
/lustre/home/2301110060/software/miniconda3/envs/DeepRUOT/lib/python3.10/site-packages/anndata/_core/aligned_df.py:68: ImplicitModificationWarning: Transforming to str index.
warnings.warn("Transforming to str index.", ImplicitModificationWarning)
/lustre/home/2301110060/software/miniconda3/envs/DeepRUOT/lib/python3.10/site-packages/anndata/_core/anndata.py:1756: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
utils.warn_names_duplicates("obs")
5
[ ]:
from DeepRUOT.utils import (
    OTPlanSampler,
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    SchrodingerBridgeConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
    generate_state_trajectory,
)
from DeepRUOT.models import scoreNet2

# Number of cells observed at time 0; passed as batch_size to generate_state_trajectory.
batch_size = df[df['samples'] == 0].shape[0]
sigma = 0.25
time = torch.Tensor(groups)
SF2M = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)
# Score network; its output is later treated as log s(x, t).
sf2m_score_model = scoreNet2(in_out_dim=dim, hidden_dim=128, activation='leakyrelu').float().to(device)
sf2m_optimizer = torch.optim.Adam(sf2m_score_model.parameters(), 1e-4)
# Trajectories of the pretrained dynamics that the score model is fit along.
trajectory = generate_state_trajectory(X, n_times, batch_size, f_net, time, device)
[27]:
from DeepRUOT.utils import get_batch

# Fit the score model by score matching on SF2M bridge samples.
N_SCORE_STEPS = 3001
LOG_EVERY = 100
lambda_penalty = 0  # weight on max(relu(log s)); see the note below on keeping -log s >= 0
for i in tqdm(range(N_SCORE_STEPS)):
    sf2m_optimizer.zero_grad()
    t, xt, ut, eps = get_batch(SF2M, X, trajectory, batch_size, n_times, return_noise=True)
    t = torch.unsqueeze(t, 1)
    # Force lambda_t to a (batch, 1) column. `t` is already (batch, 1) here, so the
    # previous `lambda_t[:, None]` gave (batch, 1, 1). That broadcast against `st`
    # (assumed to have xt's shape (batch, dim)) into a (batch, batch, dim) tensor,
    # pairing every lambda with every sample.
    lambda_t = SF2M.compute_lambda(t % 1).reshape(-1, 1)
    value_st = sf2m_score_model(t, xt)
    st = sf2m_score_model.compute_gradient(t, xt)
    positive_st = torch.relu(value_st)
    max_positive_st = torch.max(positive_st)
    penalty = lambda_penalty * max_positive_st
    score_loss = torch.mean((lambda_t * st + eps) ** 2)
    if i % LOG_EVERY == 0:
        print(max_positive_st)
        print(f"{i}: {score_loss.item():0.2f}")
    loss = score_loss + penalty
    loss.backward()
    sf2m_optimizer.step()
0%| | 2/3001 [00:00<03:37, 13.77it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
0: 0.99
3%|▎ | 102/3001 [00:07<03:23, 14.26it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
100: 0.82
7%|▋ | 202/3001 [00:14<03:16, 14.25it/s]
tensor(0.0413, device='cuda:0', grad_fn=<MaxBackward1>)
200: 0.66
10%|█ | 302/3001 [00:21<03:09, 14.23it/s]
tensor(0.2164, device='cuda:0', grad_fn=<MaxBackward1>)
300: 0.60
13%|█▎ | 402/3001 [00:28<03:02, 14.26it/s]
tensor(0.4356, device='cuda:0', grad_fn=<MaxBackward1>)
400: 0.61
17%|█▋ | 502/3001 [00:35<02:55, 14.26it/s]
tensor(0.6048, device='cuda:0', grad_fn=<MaxBackward1>)
500: 0.58
20%|██ | 602/3001 [00:42<02:48, 14.25it/s]
tensor(0.6926, device='cuda:0', grad_fn=<MaxBackward1>)
600: 0.54
23%|██▎ | 702/3001 [00:49<02:41, 14.26it/s]
tensor(0.7098, device='cuda:0', grad_fn=<MaxBackward1>)
700: 0.53
27%|██▋ | 802/3001 [00:57<03:02, 12.07it/s]
tensor(0.5767, device='cuda:0', grad_fn=<MaxBackward1>)
800: 0.53
30%|███ | 902/3001 [01:04<02:27, 14.25it/s]
tensor(0.4652, device='cuda:0', grad_fn=<MaxBackward1>)
900: 0.53
33%|███▎ | 1002/3001 [01:11<02:20, 14.25it/s]
tensor(0.4627, device='cuda:0', grad_fn=<MaxBackward1>)
1000: 0.54
37%|███▋ | 1102/3001 [01:18<02:13, 14.24it/s]
tensor(0.4661, device='cuda:0', grad_fn=<MaxBackward1>)
1100: 0.54
40%|████ | 1202/3001 [01:25<02:06, 14.27it/s]
tensor(0.4728, device='cuda:0', grad_fn=<MaxBackward1>)
1200: 0.55
43%|████▎ | 1302/3001 [01:32<01:59, 14.27it/s]
tensor(0.4222, device='cuda:0', grad_fn=<MaxBackward1>)
1300: 0.52
47%|████▋ | 1402/3001 [01:39<01:52, 14.25it/s]
tensor(0.4192, device='cuda:0', grad_fn=<MaxBackward1>)
1400: 0.54
50%|█████ | 1502/3001 [01:46<01:45, 14.25it/s]
tensor(0.5344, device='cuda:0', grad_fn=<MaxBackward1>)
1500: 0.53
53%|█████▎ | 1602/3001 [01:53<01:38, 14.25it/s]
tensor(0.5400, device='cuda:0', grad_fn=<MaxBackward1>)
1600: 0.54
57%|█████▋ | 1702/3001 [02:00<01:31, 14.25it/s]
tensor(0.5698, device='cuda:0', grad_fn=<MaxBackward1>)
1700: 0.52
60%|██████ | 1802/3001 [02:07<01:24, 14.25it/s]
tensor(0.4072, device='cuda:0', grad_fn=<MaxBackward1>)
1800: 0.50
63%|██████▎ | 1902/3001 [02:14<01:17, 14.23it/s]
tensor(0.1181, device='cuda:0', grad_fn=<MaxBackward1>)
1900: 0.52
67%|██████▋ | 2002/3001 [02:21<01:10, 14.25it/s]
tensor(0.2270, device='cuda:0', grad_fn=<MaxBackward1>)
2000: 0.49
70%|███████ | 2102/3001 [02:28<01:02, 14.27it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2100: 0.53
73%|███████▎ | 2202/3001 [02:35<00:56, 14.27it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2200: 0.48
77%|███████▋ | 2302/3001 [02:42<00:49, 14.25it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2300: 0.47
80%|████████ | 2402/3001 [02:49<00:41, 14.27it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2400: 0.53
83%|████████▎ | 2502/3001 [02:56<00:35, 14.25it/s]
tensor(0.1869, device='cuda:0', grad_fn=<MaxBackward1>)
2500: 0.49
87%|████████▋ | 2602/3001 [03:03<00:28, 14.23it/s]
tensor(0.2245, device='cuda:0', grad_fn=<MaxBackward1>)
2600: 0.55
90%|█████████ | 2702/3001 [03:10<00:20, 14.27it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2700: 0.50
93%|█████████▎| 2802/3001 [03:17<00:13, 14.26it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2800: 0.50
97%|█████████▋| 2902/3001 [03:24<00:06, 14.27it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2900: 0.46
100%|██████████| 3001/3001 [03:31<00:00, 14.21it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
3000: 0.48
Please note that the -log s values at each time point must not become negative (equivalently, the learned log s must stay ≤ 0). If they do, this could lead to exponential blow-up during the subsequent `train_all`. You can check this in the output of the score pretraining above, e.g. in lines like:
`tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)` followed by `3000: 0.48`
The tensor is the largest positive part of log s in the batch; if it is not 0, that might indicate a problem. To mitigate this, one trick is to adjust the `lambda_penalty` parameter: try setting it to 1 for some epochs, then switch it to 0 for additional epochs (see Section C.3, "Training Initial Log Density Function", in our paper). This can help stabilize training and avoid negative -log s. Alternatively, you can set `lambda_initial=0` in `train_all`, which removes the term that causes this instability.
[28]:
# Persist the pretrained score network's weights to exp_dir/score_model.
torch.save(obj=sf2m_score_model.state_dict(), f=os.path.join(exp_dir, 'score_model'))
[30]:
import numpy as np
import matplotlib.pyplot as plt
import torch

# Visualize the pretrained score model at one time point: a filled contour of
# -log s(x, t) with arrows for grad_x log s(x, t) on a regular grid.
GRID_N = 100
x_range = np.linspace(0, 2, GRID_N)
y_range = np.linspace(0, 2.5, GRID_N)
xv, yv = np.meshgrid(x_range, y_range)
grid_points = np.stack([xv, yv], axis=-1).reshape(-1, 2)
grid_points_tensor = torch.tensor(grid_points, dtype=torch.float32, device=device).requires_grad_(True)
t_value = 1.0
t_tensor = torch.full((grid_points.shape[0], 1), t_value, dtype=torch.float32, device=device)
print(t_tensor.shape)
log_density_values = sf2m_score_model(t_tensor, grid_points_tensor)
# torch.autograd.grad returns d(sum log s)/dx directly. Unlike .backward(), it does not
# accumulate into the score model's parameter .grad buffers. Those buffers survive
# load_state_dict and could otherwise leak into a later optimizer step that is not
# preceded by zero_grad().
(gradients,) = torch.autograd.grad(log_density_values.sum(), grid_points_tensor)
print(gradients.shape)
gradients_np = gradients.detach().cpu().numpy().reshape(GRID_N, GRID_N, 2)

# Subsample the arrows so the quiver stays readable.
step = 5
xv_quiver = xv[::step, ::step]
yv_quiver = yv[::step, ::step]
gradients_np_quiver = gradients_np[::step, ::step, :]

fig, ax = plt.subplots(figsize=(8, 6))
contour = ax.contourf(
    xv, yv, -log_density_values.detach().cpu().numpy().reshape(GRID_N, GRID_N),
    levels=50, cmap='rainbow',
)
fig.colorbar(contour, ax=ax, label='-log s(x, t)')
ax.quiver(xv_quiver, yv_quiver, gradients_np_quiver[:, :, 0], gradients_np_quiver[:, :, 1], color='white', scale=10)
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_title(f'Score model at t = {t_value}')
plt.show()
torch.Size([10000, 1])
torch.Size([10000, 2])
Train
[31]:
# Reload the pretrained score model and the best dynamics checkpoint, then re-enable
# gradients for the growth network (frozen before pretraining) for joint training.
sf2m_score_model.load_state_dict(torch.load(os.path.join(exp_dir, 'score_model'), map_location=torch.device('cpu')))
sf2m_score_model.to(device)
f_net.load_state_dict(torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu')))
f_net.to(device)
f_net.g_net.requires_grad_(True);
[31]:
FNet(
(v_net): velocityNet(
(activation): LeakyReLU(negative_slope=0.01)
(net): ModuleList(
(0): Sequential(
(0): Linear(in_features=3, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
)
(1-2): 2 x Sequential(
(0): Linear(in_features=128, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
)
)
(out): Linear(in_features=128, out_features=2, bias=True)
)
(g_net): growthNet(
(activation): LeakyReLU(negative_slope=0.01)
(net): Sequential(
(0): Linear(in_features=3, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=128, out_features=128, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=128, out_features=128, bias=True)
(5): LeakyReLU(negative_slope=0.01)
(6): Linear(in_features=128, out_features=1, bias=True)
)
)
(s_net): scoreNet(
(activation): LeakyReLU(negative_slope=0.01)
(net): Sequential(
(0): Linear(in_features=3, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=128, out_features=128, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=128, out_features=128, bias=True)
(5): LeakyReLU(negative_slope=0.01)
(6): Linear(in_features=128, out_features=1, bias=True)
)
)
(d_net): indediffusionNet(
(activation): LeakyReLU(negative_slope=0.01)
(net): Sequential(
(0): Linear(in_features=1, out_features=128, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=128, out_features=128, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=128, out_features=128, bias=True)
(5): LeakyReLU(negative_slope=0.01)
(6): Linear(in_features=128, out_features=1, bias=True)
)
)
)
[44]:
import numpy as np
from DeepRUOT.utils import density1
import matplotlib.pyplot as plt
import torch

# Density of the observed time-0 cells (density1) evaluated on a regular grid.
# A dedicated name is used so this cell does not overwrite the notebook-wide `device`.
density_device = 'cpu'
# Time-0 coordinates as a float32 (n_cells, 2) tensor; also passed to train_all below.
datatime0 = torch.tensor(
    df.loc[df['samples'] == 0, ['x1', 'x2']].to_numpy(), dtype=torch.float32
)
x_range = np.linspace(0, 2, 100)
y_range = np.linspace(0, 2.5, 100)
xv, yv = np.meshgrid(x_range, y_range)
grid_points = np.stack([xv, yv], axis=-1).reshape(-1, 2)
grid_points_tensor = torch.tensor(grid_points, dtype=torch.float32, device=density_device)
with torch.no_grad():
    function_values = density1(grid_points_tensor, datatime0, density_device).cpu().detach().numpy().reshape(xv.shape)

fig, ax = plt.subplots(figsize=(8, 6))
contour = ax.contourf(xv, yv, function_values, levels=100, cmap='rainbow')
fig.colorbar(contour, ax=ax, label='Density value')
ax.set_xlabel('x1')
ax.set_ylabel('x2')
# Derive the title from the actual grid; the x range is [0, 2], not [0, 2.5].
ax.set_title(
    f'Density Plot on [{x_range[0]:g}, {x_range[-1]:g}] x [{y_range[0]:g}, {y_range[-1]:g}]'
)
plt.show()
[45]:
from DeepRUOT.train import train_all
device = 'cpu'
# A single SGD optimizer drives both the dynamics network (f_net) and the score model,
# with their parameters concatenated into one list.
optimizer = torch.optim.SGD(
    [*f_net.parameters(), *sf2m_score_model.parameters()],
    lr=1e-5,
)
[46]:
if n_local_epochs > 0:
    logger.info('Beginning Training')
    # NOTE(review): the guard tests n_local_epochs, but exactly one epoch is run here.
    for epoch in tqdm(range(1), desc='Training Epoch'):
        l_loss, b_loss, g_loss = train_all(
            f_net, df, groups, optimizer, 10,
            # which losses / schedule
            local_loss=True,
            global_loss=False,
            apply_losses_in_time=apply_losses_in_time,
            criterion=criterion,
            # hold-out configuration
            hold_one_out=hold_one_out,
            hold_out=hold_out,
            # models, data and hardware
            sf2m_score_model=sf2m_score_model,
            datatime0=datatime0,
            device=device,
            use_cuda=use_cuda,
            logger=logger,
            # loss switches and weights
            hinge_value=hinge_value,
            lambda_initial=0.1,
            use_pinn=True,
            lambda_pinn=1,
            use_penalty=True,
            use_density_loss=False,
            lambda_density=10,
            sigmaa=sigma,
            # sampling / mass bookkeeping
            top_k=top_k,
            sample_size=sample_size,
            sample_with_replacement=sample_with_replacement,
            relative_mass=relative_mass,
            initial_size=initial_size,
        )
        # Accumulate per-loss histories plus the batch and global loss traces.
        for loss_name, loss_values in l_loss.items():
            local_losses[loss_name].extend(loss_values)
        batch_losses.extend(b_loss)
        globe_losses.extend(g_loss)
Training Epoch: 0%| | 0/1 [00:00<?, ?it/s]
begin local loss
Otloss
tensor(0.0328, grad_fn=<SumBackward0>)
mass loss
tensor(0.0112, grad_fn=<PowBackward0>)
energy loss
tensor(0.0008, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0667, grad_fn=<SumBackward0>)
mass loss
tensor(0.0469, grad_fn=<PowBackward0>)
energy loss
tensor(0.0020, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0883, grad_fn=<SumBackward0>)
mass loss
tensor(0.1090, grad_fn=<PowBackward0>)
energy loss
tensor(0.0035, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0804, grad_fn=<SumBackward0>)
mass loss
tensor(0.3514, grad_fn=<PowBackward0>)
energy loss
tensor(0.0055, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0328, grad_fn=<SumBackward0>)
mass loss
tensor(0.0112, grad_fn=<PowBackward0>)
energy loss
tensor(0.0008, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0667, grad_fn=<SumBackward0>)
mass loss
tensor(0.0472, grad_fn=<PowBackward0>)
energy loss
tensor(0.0020, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0882, grad_fn=<SumBackward0>)
mass loss
tensor(0.1090, grad_fn=<PowBackward0>)
energy loss
tensor(0.0035, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0803, grad_fn=<SumBackward0>)
mass loss
tensor(0.3497, grad_fn=<PowBackward0>)
energy loss
tensor(0.0055, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0328, grad_fn=<SumBackward0>)
mass loss
tensor(0.0111, grad_fn=<PowBackward0>)
energy loss
tensor(0.0008, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0666, grad_fn=<SumBackward0>)
mass loss
tensor(0.0472, grad_fn=<PowBackward0>)
energy loss
tensor(0.0020, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0881, grad_fn=<SumBackward0>)
mass loss
tensor(0.1084, grad_fn=<PowBackward0>)
energy loss
tensor(0.0035, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0802, grad_fn=<SumBackward0>)
mass loss
tensor(0.3499, grad_fn=<PowBackward0>)
energy loss
tensor(0.0055, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0328, grad_fn=<SumBackward0>)
mass loss
tensor(0.0111, grad_fn=<PowBackward0>)
energy loss
tensor(0.0008, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0666, grad_fn=<SumBackward0>)
mass loss
tensor(0.0472, grad_fn=<PowBackward0>)
energy loss
tensor(0.0020, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0881, grad_fn=<SumBackward0>)
mass loss
tensor(0.1082, grad_fn=<PowBackward0>)
energy loss
tensor(0.0035, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0802, grad_fn=<SumBackward0>)
mass loss
tensor(0.3499, grad_fn=<PowBackward0>)
energy loss
tensor(0.0055, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0328, grad_fn=<SumBackward0>)
mass loss
tensor(0.0111, grad_fn=<PowBackward0>)
energy loss
tensor(0.0008, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0666, grad_fn=<SumBackward0>)
mass loss
tensor(0.0473, grad_fn=<PowBackward0>)
energy loss
tensor(0.0020, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0880, grad_fn=<SumBackward0>)
mass loss
tensor(0.1080, grad_fn=<PowBackward0>)
energy loss
tensor(0.0035, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0801, grad_fn=<SumBackward0>)
mass loss
tensor(0.3506, grad_fn=<PowBackward0>)
energy loss
tensor(0.0055, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0328, grad_fn=<SumBackward0>)
mass loss
tensor(0.0111, grad_fn=<PowBackward0>)
energy loss
tensor(0.0008, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0666, grad_fn=<SumBackward0>)
mass loss
tensor(0.0474, grad_fn=<PowBackward0>)
energy loss
tensor(0.0020, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0879, grad_fn=<SumBackward0>)
mass loss
tensor(0.1085, grad_fn=<PowBackward0>)
energy loss
tensor(0.0035, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0800, grad_fn=<SumBackward0>)
mass loss
tensor(0.3495, grad_fn=<PowBackward0>)
energy loss
tensor(0.0055, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0328, grad_fn=<SumBackward0>)
mass loss
tensor(0.0111, grad_fn=<PowBackward0>)
energy loss
tensor(0.0008, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0665, grad_fn=<SumBackward0>)
mass loss
tensor(0.0479, grad_fn=<PowBackward0>)
energy loss
tensor(0.0020, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0879, grad_fn=<SumBackward0>)
mass loss
tensor(0.1076, grad_fn=<PowBackward0>)
energy loss
tensor(0.0035, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0799, grad_fn=<SumBackward0>)
mass loss
tensor(0.3498, grad_fn=<PowBackward0>)
energy loss
tensor(0.0055, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0328, grad_fn=<SumBackward0>)
mass loss
tensor(0.0110, grad_fn=<PowBackward0>)
energy loss
tensor(0.0008, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0665, grad_fn=<SumBackward0>)
mass loss
tensor(0.0479, grad_fn=<PowBackward0>)
energy loss
tensor(0.0020, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0878, grad_fn=<SumBackward0>)
mass loss
tensor(0.1072, grad_fn=<PowBackward0>)
energy loss
tensor(0.0035, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0799, grad_fn=<SumBackward0>)
mass loss
tensor(0.3483, grad_fn=<PowBackward0>)
energy loss
tensor(0.0055, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0328, grad_fn=<SumBackward0>)
mass loss
tensor(0.0110, grad_fn=<PowBackward0>)
energy loss
tensor(0.0008, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0665, grad_fn=<SumBackward0>)
mass loss
tensor(0.0480, grad_fn=<PowBackward0>)
energy loss
tensor(0.0020, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0878, grad_fn=<SumBackward0>)
mass loss
tensor(0.1069, grad_fn=<PowBackward0>)
energy loss
tensor(0.0035, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0798, grad_fn=<SumBackward0>)
mass loss
tensor(0.3491, grad_fn=<PowBackward0>)
energy loss
tensor(0.0055, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0328, grad_fn=<SumBackward0>)
mass loss
tensor(0.0110, grad_fn=<PowBackward0>)
energy loss
tensor(0.0008, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0664, grad_fn=<SumBackward0>)
mass loss
tensor(0.0484, grad_fn=<PowBackward0>)
energy loss
tensor(0.0020, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0877, grad_fn=<SumBackward0>)
mass loss
tensor(0.1062, grad_fn=<PowBackward0>)
energy loss
tensor(0.0035, grad_fn=<MeanBackward0>)
pinloss
0
Otloss
tensor(0.0797, grad_fn=<SumBackward0>)
mass loss
tensor(0.3487, grad_fn=<PowBackward0>)
energy loss
tensor(0.0055, grad_fn=<MeanBackward0>)
pinloss
0
Training Epoch: 100%|██████████| 1/1 [00:12<00:00, 12.17s/it]
[32]:
# Persist the trained weights (state dicts only) into the experiment directory.
for _module, _file_name in (
    (sf2m_score_model, 'score_model_result'),
    (f_net, 'model_result'),
):
    torch.save(_module.state_dict(), os.path.join(exp_dir, _file_name))
[33]:
# All columns of the time-0 rows as a float32 tensor (kept as before in case later cells use it).
data = torch.tensor(df[df['samples'] == 0].values, dtype=torch.float32).requires_grad_()
# Columns 1..dim are taken as the state coordinates.
# NOTE(review): assumes column 0 is the time key and columns 1..dim are x1..x_dim; confirm df layout.
# detach().clone() makes data_t0 its own leaf tensor. A bare slice of `data` is a non-leaf,
# so .requires_grad_() on it is a no-op and any gradients would accumulate in `data` instead.
data_t0 = data[:, 1:dim + 1].detach().clone().to(device).requires_grad_()
print(data_t0.shape)
# Initial state handed to the SDE solver (already on `device`).
x0 = data_t0
torch.Size([400, 2])
[34]:
import torchsde
class SDE(torch.nn.Module):
    """Ito SDE for torchsde with a learned drift plus score correction and constant diagonal noise.

    Args:
        ode_drift: callable ``(t, y) -> Tensor``, the base drift; receives ``t`` exactly as
            passed in by the solver.
        score: object exposing ``compute_gradient(t, y)``; receives ``t`` broadcast to
            shape ``(y.shape[0], 1)``.
        input_size: stored on the instance; not read by ``f`` or ``g``.
        sigma: constant diffusion coefficient applied to every state dimension.
    """

    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, ode_drift, score, input_size=(3, 32, 32), sigma=1.0):
        super().__init__()
        self.drift = ode_drift
        self.score = score
        self.input_size = input_size
        self.sigma = sigma

    def f(self, t, y):
        """Drift: ``ode_drift(t, y) + score.compute_gradient(t_batch, y)``."""
        base_drift = self.drift(t, y)
        # One time entry per sample for the score network.
        t_batch = t.expand(y.shape[0], 1)
        # NOTE(review): the score gradient is added without any sigma-dependent scaling;
        # confirm compute_gradient already includes whatever factor the model expects.
        return base_drift + self.score.compute_gradient(t_batch, y)

    def g(self, t, y):
        """Diagonal diffusion: a tensor shaped like ``y`` filled with ``self.sigma``."""
        # Use the sigma given to the constructor, not any same-named global variable.
        return torch.ones_like(y) * self.sigma
sde = SDE(f_net.v_net, sf2m_score_model, input_size=(dim,), sigma=sigma)
# Integrate from the time-0 cells over a dense grid of 400 times spanning [0, n_times - 1]
# (used for smooth trajectory plots), then move the result to the CPU.
sde_traj = torchsde.sdeint(
    sde,
    x0.to(device),
    ts=torch.linspace(0, n_times - 1, 400, device=device),
    dt=0.01,
).cpu()
[35]:
# Randomly pick `sample_number` simulated trajectories (dimension 1 indexes samples)
# and save them as an object array for the comparison plot.
sample_number = 10
sample_indices = random.sample(range(sde_traj.size(1)), sample_number)
sampled_sde_trajec = np.array(sde_traj[:, sample_indices, :].tolist(), dtype=object)
np.save(exp_dir + '/sde_trajec_our_post_plot.npy', sampled_sde_trajec)
[36]:
# Re-simulate, this time reporting the state only at the observed time points in `time`.
ts_points = time.to(device)
sde_point = torchsde.sdeint(sde, x0.to(device), ts=ts_points, dt=0.01).cpu()
[37]:
# Detach from autograd, convert to a nested-list-backed object array, and persist it.
sde_point_np = sde_point.detach().numpy()
sde_point_list = sde_point_np.tolist()
sde_point_array = np.array(sde_point_list, dtype=object)
np.save(
    exp_dir + '/sde_point_our_post.npy',
    sde_point_array,
)
[38]:
# allow_pickle is needed because the array was saved with dtype=object by this notebook.
sde_point_our_post = np.load(
    exp_dir + '/sde_point_our_post.npy',
    allow_pickle=True,
)
[39]:
from DeepRUOT.plots import new_plot_comparisions2
[40]:
# Reload the sampled trajectories (object array; needs allow_pickle) and plot them
# against the data and the simulated points at the observed times.
sde_trajec_our_post_plot = np.load(
    exp_dir + '/sde_trajec_our_post_plot.npy',
    allow_pickle=True,
)
new_plot_comparisions2(
    df,
    sde_point_our_post,
    sde_trajec_our_post_plot,
    x='x1',
    y='x2',
    z='x3',
    is_3d=False,
    df_time_key='samples',
    palette='viridis',
    save=True,
    path=exp_dir,
    file='sde_trajec_our_post_plot.pdf',
)
[40]: