[1]:
import os, random, argparse, pandas as pd, numpy as np, seaborn as sns
from tqdm import tqdm
import torch, torch.nn as nn


# Make CUDA kernel launches synchronous so GPU errors are reported at the call that caused them.
# This is a debugging aid (it serialises GPU work) and only takes effect if set before CUDA is
# initialised, which is why it sits in the first cell.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Seed every random number generator the notebook draws from: Python, NumPy and PyTorch.
SEED = 10
for seed_rng in (random.seed, np.random.seed, torch.random.manual_seed):
    seed_rng(SEED)


# load package requirments
from DeepRUOT.losses import MMD_loss, OT_loss1, OT_loss2, Density_loss, Local_density_loss
from DeepRUOT.utils import group_extract, sample, to_np, generate_steps, cal_mass_loss, parser, _valid_criterions
from DeepRUOT.plots import plot_comparision, plot_losses
from DeepRUOT.train import train_un1
from DeepRUOT.models import velocityNet, growthNet, scoreNet, dediffusionNet, indediffusionNet, FNet, ODEFunc2
from DeepRUOT.constants import ROOT_DIR, DATA_DIR, NTBK_DIR, IMGS_DIR, RES_DIR
from DeepRUOT.exp import setup_exp
from DeepRUOT.eval import generate_plot_data
from torchdiffeq import odeint_adjoint as odeint
#from torchdiffeq import odeint

# for geodesic learning
from scipy.spatial import distance_matrix
from sklearn.gaussian_process.kernels import RBF
from sklearn.manifold import MDS
[2]:
import sys
import torch.optim as optim

# State-space dimension of the data. FNet is built for 20-D states; the printed model further
# down shows its sub-networks taking 21 inputs (presumably the 20 coordinates plus time — confirm).
dim = 20
f_net = FNet(in_out_dim=dim, hidden_dim=128, n_hiddens=3, activation='leakyrelu')

# Command-line options for the package's argparse parser. Passing the list straight to
# parse_args avoids overwriting the kernel-wide sys.argv.
cli_args = [
    '-d', 'file',
    '-c', 'ot1',
    # Experiment name handed to setup_exp; spelling kept as-is since it may already name saved results.
    '-n', 'guassian_20d',
]
args = parser.parse_args(cli_args)
opts = vars(args)

# Display the parsed arguments
print(opts)

# Use the GPU only when one is available AND the 'cuda' option is enabled (the same rule used for
# use_cuda later on); a hard-coded 'cuda' device fails on CPU-only machines at the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() and opts['cuda'] else 'cpu')
device

{'dataset': 'file', 'time_col': None, 'name': 'guassian_20d', 'output_dir': '/lustre/home/2301110060/DeepRUOT/results', 'local_epochs': 5, 'epochs': 15, 'local_post_epochs': 5, 'criterion': 'ot1', 'batches': 100, 'cuda': True, 'sample_size': 100, 'sample_with_replacement': False, 'hold_one_out': True, 'hold_out': 'random', 'apply_losses_in_time': True, 'top_k': 5, 'hinge_value': 0.01, 'use_density_loss': True, 'use_local_density': False, 'lambda_density': 1.0, 'lambda_density_local': 1.0, 'lambda_local': 0.2, 'lambda_global': 0.8, 'model_layers': [64], 'use_geo': False, 'geo_layers': [32], 'geo_features': 5, 'n_points': 100, 'n_trajectories': 30, 'n_bins': 100}
[2]:
device(type='cuda')
[3]:
# Create the results directory (no-op if it already exists) and set up the experiment
# directory and logger.
os.makedirs(opts['output_dir'], exist_ok=True)
exp_dir, logger = setup_exp(opts['output_dir'], opts, opts['name'])

# Log first, then load, so the log line actually precedes the work it announces.
logger.info('Loading dataset')
df = pd.read_csv(os.path.join(DATA_DIR, 'gaussian_20d.csv'))

[4]:
 # setup groups
# Sorted time points present in the data, and the (t0, t1) pairs generate_steps builds from them.
groups = sorted(df.samples.unique())
steps = generate_steps(groups)

logger.info('Defining model')
use_geo = opts['use_geo']
model_layers = opts['model_layers']
model_features = len(df.columns) - 1  # all columns but one (presumably all but 'samples' — confirm)

logger.info('Defining optimizer and criterion')
# Move the model to its device before constructing the optimizer, as the PyTorch optimizer docs
# recommend, so the optimizer is built over the parameters that will actually be trained.
f_net = f_net.to(device)
optimizer = torch.optim.Adam(f_net.parameters())

# The criterion comes from the parsed options ('-c ot1'). It is no longer hard-overridden here;
# the old override silently ignored whatever criterion was requested on the command line.
criterion = _valid_criterions[opts['criterion']]()

logger.info('Extracting parameters')
use_cuda = torch.cuda.is_available() and opts['cuda']
# NOTE: a later cell replaces this with the row count at the first time point.
sample_size = (opts['sample_size'],)
sample_with_replacement = opts['sample_with_replacement']
apply_losses_in_time = opts['apply_losses_in_time']

# Epoch / batch schedule
n_local_epochs = opts['local_epochs']
n_epochs = opts['epochs']
n_post_local_epochs = opts['local_post_epochs']
n_batches = opts['batches']

# Hold-out configuration
hold_one_out = opts['hold_one_out']
hold_out = opts['hold_out']

# Density-loss settings
hinge_value = opts['hinge_value']
top_k = opts['top_k']
lambda_density = opts['lambda_density']
lambda_density_local = opts['lambda_density_local']
use_density_loss = opts['use_density_loss']
use_local_density = opts['use_local_density']

# Weights of the local vs. global loss terms
lambda_local = opts['lambda_local']
lambda_global = opts['lambda_global']

# Trajectory / plotting settings
n_points = opts['n_points']
n_trajectories = opts['n_trajectories']
n_bins = opts['n_bins']

# Loss histories accumulated across training cells; local losses are keyed by 't0:t1'.
local_losses = {f'{t0}:{t1}': [] for (t0, t1) in steps}
batch_losses = []
globe_losses = []
[5]:
# nn.Module.to moves the parameters in place and returns the same module, so no rebinding is needed.
f_net.to(device)
f_net
[5]:
FNet(
  (v_net): velocityNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=21, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
      (1-2): 2 x Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
    )
    (out): Linear(in_features=128, out_features=20, bias=True)
  )
  (g_net): growthNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=21, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (s_net): scoreNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=21, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (d_net): indediffusionNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=1, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
[6]:
# Number of rows at the first time point (samples == 0); passed to train_un1 as initial_size.
initial_size = len(df.loc[df['samples'] == 0, 'x1'])
initial_size
[6]:
500
[7]:
# Relative population size per time point: row count for each 'samples' value (groupby sorts the
# keys ascending) divided by the count at the first time point, kept as a float64 tensor.
sample_sizes = df.groupby('samples').size()
ref0 = sample_sizes.div(sample_sizes.iloc[0])
relative_mass = torch.tensor(ref0.to_numpy())
relative_mass
[7]:
tensor([1.0000, 2.8000], dtype=torch.float64)
[8]:
# Replace the sample_size read from opts with the full row count at time 0 (the same count as
# initial_size). Kept as a 1-tuple, the form in which it is passed to train_un1 below.
sample_size = (int((df['samples'] == 0.0).sum()),)

Pretrain the velocity and growth networks

[9]:
    # Stage 1: pretrain velocity and growth on the local OT loss with the mass term on
    # (lambda_mass=1). train_un1 saves the lowest-OT-loss model so far to exp_dir + '/best_model'
    # (see the "Model saved." lines in the output).
    # NOTE(review): n_local_epochs only gates this stage; train_un1 is called exactly once.
    if n_local_epochs > 0:
        logger.info('Beginning pretraining')
        # Positional count handed to train_un1 — appears to be its iteration count
        # (one "Otloss" line per iteration in the output); confirm.
        stage1_iters = 30
        for epoch in tqdm(range(1), desc='Pretraining Epoch'):
            l_loss, b_loss, g_loss = train_un1(
                f_net, df, groups, optimizer, stage1_iters,
                criterion=criterion,
                use_cuda=use_cuda,
                local_loss=True,
                global_loss=False,
                apply_losses_in_time=apply_losses_in_time,
                hold_one_out=hold_one_out,
                hold_out=hold_out,
                hinge_value=hinge_value,
                lambda_ot=0.1,
                lambda_mass=1,
                lambda_energy=0.001,
                use_pinn=False,
                use_penalty=False,
                use_density_loss=False,  # presumably makes lambda_density below inert — confirm
                lambda_density=10,
                top_k=top_k,
                sample_size=sample_size,
                relative_mass=relative_mass,
                initial_size=initial_size,
                sample_with_replacement=sample_with_replacement,
                logger=logger,
                device=device,
                best_model_path=exp_dir + '/best_model',
            )
            for interval, interval_losses in l_loss.items():
                local_losses[interval].extend(interval_losses)
            batch_losses.extend(b_loss)
            globe_losses.extend(g_loss)

Pretraining Epoch:   0%|          | 0/1 [00:00<?, ?it/s]
begin local loss
Otloss
tensor(3.4642, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3492, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 3.4642062187194824. Model saved.
Otloss
tensor(3.3166, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3815, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 3.3165838718414307. Model saved.
Otloss
tensor(3.1762, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4329, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 3.176236152648926. Model saved.
Otloss
tensor(3.0371, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5090, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 3.0371193885803223. Model saved.
Otloss
tensor(2.8961, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5274, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 2.896106004714966. Model saved.
Otloss
tensor(2.7507, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5149, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 2.750743865966797. Model saved.
Otloss
tensor(2.5998, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4947, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 2.5997862815856934. Model saved.
Otloss
tensor(2.4441, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4838, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 2.444066047668457. Model saved.
Otloss
tensor(2.2831, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4680, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 2.283076763153076. Model saved.
Otloss
tensor(2.1132, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4615, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 2.1132123470306396. Model saved.
Otloss
tensor(1.9340, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4943, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 1.933964729309082. Model saved.
Otloss
tensor(1.7474, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5413, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 1.7473881244659424. Model saved.
Otloss
tensor(1.5573, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6060, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 1.557271957397461. Model saved.
Otloss
tensor(1.3706, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6492, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 1.370648741722107. Model saved.
Otloss
tensor(1.1939, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6510, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 1.1938560009002686. Model saved.
Otloss
tensor(1.0322, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.6326, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 1.0321886539459229. Model saved.
Otloss
tensor(0.9261, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5990, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.9260845184326172. Model saved.
Otloss
tensor(0.9111, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.5292, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.9110905528068542. Model saved.
Otloss
tensor(0.9310, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4738, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.9632, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.4151, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.9936, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3809, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(1.0128, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3375, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(1.0171, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3102, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(1.0057, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.3005, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.9815, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2952, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.9480, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2942, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.9084, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2913, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.9084140658378601. Model saved.
Otloss
tensor(0.8643, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2907, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.8642805814743042. Model saved.
Otloss
tensor(0.8180, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2964, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.8179771900177002. Model saved.
Otloss
tensor(0.7726, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2995, device='cuda:0', grad_fn=<PowBackward0>)
Pretraining Epoch: 100%|██████████| 1/1 [00:21<00:00, 21.40s/it]
New minimum otloss found: 0.7725681066513062. Model saved.

[10]:
# Restore the lowest-OT-loss checkpoint written by the previous training cell, then freeze the
# growth network (g_net) so the next stage only updates the other sub-networks.
best_model_path = os.path.join(exp_dir, 'best_model')
f_net.load_state_dict(torch.load(best_model_path, map_location=device))
f_net.to(device)
for param in f_net.g_net.parameters():
    param.requires_grad = False

# Rebuild the optimizer over the parameters that are still trainable. The previous Adam instance
# still listed the frozen g_net parameters. It also carried moment estimates from the last step
# of the previous stage, not from the checkpoint just loaded. Adam keeps moving a parameter whose
# .grad is a zero tensor (rather than None), so requires_grad=False alone does not guarantee that
# g_net stays fixed.
optimizer = torch.optim.Adam(p for p in f_net.parameters() if p.requires_grad)

[13]:
    # Stage 2: continue from the restored checkpoint with g_net frozen (previous cell) and the mass
    # term switched off (lambda_mass=0). The "mass loss" values in the output are therefore
    # presumably reported only, not optimized.
    # NOTE(review): this reuses exp_dir + '/best_model', so the stage-1 checkpoint is overwritten as
    # soon as this stage saves (first iteration in the output below).
    # NOTE(review): execution count jumps from [10] to [13]; confirm nothing here relied on state
    # from the since-removed cells 11-12.
    if n_local_epochs > 0:
        logger.info('Beginning pretraining')
        # Positional count handed to train_un1 — appears to be its iteration count; confirm.
        stage2_iters = 250
        for epoch in tqdm(range(1), desc='Pretraining Epoch'):
            l_loss, b_loss, g_loss = train_un1(
                f_net, df, groups, optimizer, stage2_iters,
                criterion=criterion,
                use_cuda=use_cuda,
                local_loss=True,
                global_loss=False,
                apply_losses_in_time=apply_losses_in_time,
                hold_one_out=hold_one_out,
                hold_out=hold_out,
                hinge_value=hinge_value,
                lambda_ot=0.1,
                lambda_mass=0,
                lambda_energy=0.001,
                use_pinn=False,
                use_penalty=False,
                use_density_loss=False,  # presumably makes lambda_density below inert — confirm
                lambda_density=10,
                top_k=top_k,
                sample_size=sample_size,
                relative_mass=relative_mass,
                initial_size=initial_size,
                sample_with_replacement=sample_with_replacement,
                logger=logger,
                device=device,
                best_model_path=exp_dir + '/best_model',
            )
            for interval, interval_losses in l_loss.items():
                local_losses[interval].extend(interval_losses)
            batch_losses.extend(b_loss)
            globe_losses.extend(g_loss)

Pretraining Epoch:   0%|          | 0/1 [00:00<?, ?it/s]
begin local loss
Otloss
tensor(0.7331, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2986, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.7331304550170898. Model saved.
Otloss
tensor(0.7212, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2921, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.7211790084838867. Model saved.
Otloss
tensor(0.7078, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2838, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.7078102827072144. Model saved.
Otloss
tensor(0.6936, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2730, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.6936277151107788. Model saved.
Otloss
tensor(0.6787, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2611, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.6786981821060181. Model saved.
Otloss
tensor(0.6630, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2519, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.663044810295105. Model saved.
Otloss
tensor(0.6467, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2435, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.6467463970184326. Model saved.
Otloss
tensor(0.6302, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2381, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.6301634311676025. Model saved.
Otloss
tensor(0.6134, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2352, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.6134482622146606. Model saved.
Otloss
tensor(0.5964, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2298, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.5964141488075256. Model saved.
Otloss
tensor(0.5792, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2281, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.5792098045349121. Model saved.
Otloss
tensor(0.5620, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2256, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.5620136260986328. Model saved.
Otloss
tensor(0.5450, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2260, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.5449652075767517. Model saved.
Otloss
tensor(0.5282, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2270, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.5282033085823059. Model saved.
Otloss
tensor(0.5118, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2274, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.5117811560630798. Model saved.
Otloss
tensor(0.4957, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2255, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.4956590533256531. Model saved.
Otloss
tensor(0.4799, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2265, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.4798792600631714. Model saved.
Otloss
tensor(0.4645, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2286, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.4645465314388275. Model saved.
Otloss
tensor(0.4498, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2295, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.4497615098953247. Model saved.
Otloss
tensor(0.4356, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2285, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.43557965755462646. Model saved.
Otloss
tensor(0.4221, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2277, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.42214110493659973. Model saved.
Otloss
tensor(0.4094, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2238, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.4094284176826477. Model saved.
Otloss
tensor(0.3975, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2263, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.3974587023258209. Model saved.
Otloss
tensor(0.3863, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2253, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.3862609267234802. Model saved.
Otloss
tensor(0.3757, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2288, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.37572383880615234. Model saved.
Otloss
tensor(0.3658, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2280, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.3658371567726135. Model saved.
Otloss
tensor(0.3572, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2257, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.35719695687294006. Model saved.
Otloss
tensor(0.3622, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2242, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3676, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2235, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3731, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2253, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3786, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2255, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3839, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2224, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3890, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2197, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3937, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2181, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3979, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2125, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4015, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2064, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4047, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2024, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4073, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.2016, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4093, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1956, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4109, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1916, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4122, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1891, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4131, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1874, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4137, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1861, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4140, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1868, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4140, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1877, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4137, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1861, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4134, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1831, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4132, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1820, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4129, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1803, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4123, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1768, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4115, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1761, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4103, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1782, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4089, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1776, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4072, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1766, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4049, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1751, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.4019, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1732, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3982, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1729, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3938, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1732, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3886, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1723, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3826, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1717, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3759, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1723, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3685, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1752, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3604, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1748, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.3512, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1741, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.3512299358844757. Model saved.
Otloss
tensor(0.3409, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1714, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.34086132049560547. Model saved.
Otloss
tensor(0.3291, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1641, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.32905030250549316. Model saved.
Otloss
tensor(0.3158, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1637, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.31582117080688477. Model saved.
Otloss
tensor(0.3011, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1705, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.3011268973350525. Model saved.
Otloss
tensor(0.2863, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1781, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.28631168603897095. Model saved.
Otloss
tensor(0.2736, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1848, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.2735878825187683. Model saved.
Otloss
tensor(0.2627, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1916, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.2627209722995758. Model saved.
Otloss
tensor(0.2516, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1806, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.2515758275985718. Model saved.
Otloss
tensor(0.2402, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1507, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.2402287870645523. Model saved.
Otloss
tensor(0.2287, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.1192, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.22868338227272034. Model saved.
Otloss
tensor(0.2182, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0884, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.21823400259017944. Model saved.
Otloss
tensor(0.2099, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0743, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.20994432270526886. Model saved.
Otloss
tensor(0.2043, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0726, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.20429208874702454. Model saved.
Otloss
tensor(0.1990, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0686, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.19898957014083862. Model saved.
Otloss
tensor(0.1931, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0649, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.19308674335479736. Model saved.
Otloss
tensor(0.1871, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0591, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.1870669722557068. Model saved.
Otloss
tensor(0.1818, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0589, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.18183431029319763. Model saved.
Otloss
tensor(0.1787, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0565, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.17872890830039978. Model saved.
Otloss
tensor(0.1782, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0557, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.17815589904785156. Model saved.
Otloss
tensor(0.1789, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0578, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1793, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0575, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1781, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0532, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.178134024143219. Model saved.
Otloss
tensor(0.1755, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0530, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.1754778027534485. Model saved.
Otloss
tensor(0.1729, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0527, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.17287935316562653. Model saved.
Otloss
tensor(0.1711, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0505, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.17110714316368103. Model saved.
Otloss
tensor(0.1703, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0500, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.170296311378479. Model saved.
Otloss
tensor(0.1704, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0518, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1709, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0515, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1716, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0505, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1723, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0495, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1730, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0493, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1735, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0497, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1733, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0517, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1724, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0549, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1711, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0557, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1702, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0524, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.17022699117660522. Model saved.
Otloss
tensor(0.1704, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0536, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1714, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0545, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1725, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0535, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1732, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0540, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1731, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0548, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1725, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0561, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1718, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0544, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1710, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0542, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1706, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0538, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1706, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0525, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1702, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0535, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.1701781451702118. Model saved.
Otloss
tensor(0.1691, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0540, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.16906479001045227. Model saved.
Otloss
tensor(0.1677, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0532, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.16768932342529297. Model saved.
Otloss
tensor(0.1667, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0542, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.1667085886001587. Model saved.
Otloss
tensor(0.1662, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0552, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.16616453230381012. Model saved.
Otloss
tensor(0.1656, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0549, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.1656438410282135. Model saved.
Otloss
tensor(0.1649, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0552, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.164910688996315. Model saved.
Otloss
tensor(0.1638, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0562, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.16382265090942383. Model saved.
Otloss
tensor(0.1627, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0558, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.16272491216659546. Model saved.
Otloss
tensor(0.1621, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0544, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.16214200854301453. Model saved.
Otloss
tensor(0.1617, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0523, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.16168060898780823. Model saved.
Otloss
tensor(0.1603, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0509, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.16030660271644592. Model saved.
Otloss
tensor(0.1585, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0514, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.15852797031402588. Model saved.
Otloss
tensor(0.1575, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0526, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.15752866864204407. Model saved.
Otloss
tensor(0.1571, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0529, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.15708552300930023. Model saved.
Otloss
tensor(0.1563, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0540, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.15627995133399963. Model saved.
Otloss
tensor(0.1550, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0557, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.15501251816749573. Model saved.
Otloss
tensor(0.1542, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0549, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.15422654151916504. Model saved.
Otloss
tensor(0.1537, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0550, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.15365956723690033. Model saved.
Otloss
tensor(0.1528, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0551, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.15275242924690247. Model saved.
Otloss
tensor(0.1514, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0559, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.15143916010856628. Model saved.
Otloss
tensor(0.1507, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0552, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.15069769322872162. Model saved.
Otloss
tensor(0.1498, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0554, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.14983034133911133. Model saved.
Otloss
tensor(0.1481, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0543, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.14812082052230835. Model saved.
Otloss
tensor(0.1471, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0541, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.14707276225090027. Model saved.
Otloss
tensor(0.1453, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0551, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.14530467987060547. Model saved.
Otloss
tensor(0.1433, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0559, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.14328408241271973. Model saved.
Otloss
tensor(0.1417, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0550, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.14167150855064392. Model saved.
Otloss
tensor(0.1394, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0554, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.1393733024597168. Model saved.
Otloss
tensor(0.1381, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0553, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.13814082741737366. Model saved.
Otloss
tensor(0.1357, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0536, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.13566741347312927. Model saved.
Otloss
tensor(0.1357, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0522, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1295, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0540, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.12948617339134216. Model saved.
Otloss
tensor(0.1372, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0529, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1225, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0564, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.12251230329275131. Model saved.
Otloss
tensor(0.1289, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0522, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1270, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0511, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1174, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0522, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.11744575202465057. Model saved.
Otloss
tensor(0.1230, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0489, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1240, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0491, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1153, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0487, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.11532987654209137. Model saved.
Otloss
tensor(0.1208, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0471, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1210, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0468, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1137, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0472, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.11372380703687668. Model saved.
Otloss
tensor(0.1216, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0465, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1134, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0479, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.11344195157289505. Model saved.
Otloss
tensor(0.1112, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0486, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.11122848093509674. Model saved.
Otloss
tensor(0.1188, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0465, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1083, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0497, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.108262799680233. Model saved.
Otloss
tensor(0.1167, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0439, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1070, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0466, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.10695388913154602. Model saved.
Otloss
tensor(0.1052, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0495, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.10524171590805054. Model saved.
Otloss
tensor(0.1239, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0405, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1031, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0459, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.1030769944190979. Model saved.
Otloss
tensor(0.1026, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0518, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.10256614536046982. Model saved.
Otloss
tensor(0.1157, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0406, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1115, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0404, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0981, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0512, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.09812656044960022. Model saved.
Otloss
tensor(0.0992, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0499, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0977, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0419, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.09769077599048615. Model saved.
Otloss
tensor(0.1079, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0374, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1049, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0394, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0963, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0462, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.0963270366191864. Model saved.
Otloss
tensor(0.0937, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0480, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.0936737209558487. Model saved.
Otloss
tensor(0.0974, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0422, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1064, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0377, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1086, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0362, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1010, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0391, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0934, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0461, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.09341716021299362. Model saved.
Otloss
tensor(0.0929, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0464, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.09289030730724335. Model saved.
Otloss
tensor(0.1020, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0381, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1063, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0368, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0997, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0380, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0977, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0431, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0948, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0394, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0970, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0393, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0982, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0444, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0939, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0349, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1024, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0355, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0933, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0353, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0963, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0395, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0932, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0368, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0976, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0342, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0943, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0333, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0928, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0342, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.09276648610830307. Model saved.
Otloss
tensor(0.0930, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0345, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0921, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0325, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.0921473428606987. Model saved.
Otloss
tensor(0.0943, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0307, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0925, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0313, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0911, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0315, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.09113812446594238. Model saved.
Otloss
tensor(0.0914, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0317, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0908, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0306, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.09075680375099182. Model saved.
Otloss
tensor(0.0929, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0310, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0921, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0297, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0898, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0309, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.08983145654201508. Model saved.
Otloss
tensor(0.0894, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0298, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.08939330279827118. Model saved.
Otloss
tensor(0.0903, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0297, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0926, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0282, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0915, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0269, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0900, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0274, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0900, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0305, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0901, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0303, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0902, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0303, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0904, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0301, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0903, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0291, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0893, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0287, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.08929537236690521. Model saved.
Otloss
tensor(0.0880, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0278, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.0879611074924469. Model saved.
Otloss
tensor(0.0870, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0280, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.08699000626802444. Model saved.
Otloss
tensor(0.0882, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0273, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0924, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0261, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0968, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0252, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0980, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0250, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0911, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0252, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0891, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0291, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0895, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0286, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0860, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0295, device='cuda:0', grad_fn=<PowBackward0>)
New minimum otloss found: 0.08603987097740173. Model saved.
Otloss
tensor(0.0872, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0300, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0914, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0295, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0977, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0310, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1010, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0291, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0972, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0255, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0878, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0231, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0918, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0241, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0902, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0217, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1035, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0262, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0886, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0318, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1086, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0334, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1016, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0226, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1153, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0343, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1147, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0302, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0995, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0266, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1181, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0550, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1308, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0457, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1030, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0284, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0970, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0411, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1343, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0651, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1187, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0464, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.0969, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0340, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1152, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0467, device='cuda:0', grad_fn=<PowBackward0>)
Otloss
tensor(0.1214, device='cuda:0', grad_fn=<SumBackward0>)
mass loss
tensor(0.0581, device='cuda:0', grad_fn=<PowBackward0>)
Pretraining Epoch: 100%|██████████| 1/1 [02:30<00:00, 150.46s/it]
[14]:
# Restore the best checkpoint written during pretraining and move the model to the
# compute device. Directory and file name are passed to os.path.join separately
# (as done for 'score_model' elsewhere) instead of concatenating the path by hand.
f_net.load_state_dict(torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu')))
f_net.to(device)
[14]:
FNet(
  (v_net): velocityNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=21, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
      (1-2): 2 x Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
    )
    (out): Linear(in_features=128, out_features=20, bias=True)
  )
  (g_net): growthNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=21, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (s_net): scoreNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=21, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (d_net): indediffusionNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=1, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
[15]:
import torch
import matplotlib.pyplot as plt
import numpy as np

def plot_g_values(df, f_net, device=device, output_file='plot.pdf'):
    """Scatter genes x1/x2 of every time point, coloured by the predicted growth rate.

    For each distinct value of ``df['samples']`` the rows' feature columns
    ``x1`` .. ``x{dim}`` (``dim`` is the notebook-level feature count) are passed
    to ``f_net(t, data)`` with that time, and the second returned value is used as
    the growth rate ``g``. The colour scale is clipped at the 99th percentile of
    all ``g`` so a few extreme cells do not wash out the colour map.

    Args:
        df: DataFrame with a ``samples`` column and columns ``x1`` .. ``x{dim}``.
        f_net: model called as ``f_net(t, data)``; it must return four values,
            the second of which is plotted.
        device: torch device the inputs are moved to.
        output_file: path the figure is written to before it is shown; pass
            ``None`` or ``''`` to skip saving.

    Raises:
        ValueError: if ``df`` has no rows to plot.
    """
    column_names = [f'x{i}' for i in range(1, dim + 1)]

    data_by_time = {}
    for time_point in df['samples'].unique():
        subset = df[df['samples'] == time_point]
        data = torch.tensor(subset[column_names].to_numpy(), dtype=torch.float32, device=device)
        with torch.no_grad():
            t = torch.tensor([time_point], dtype=torch.float32, device=device)
            _, g, _, _ = f_net(t, data)
        # Flatten so every cell maps to exactly one scalar colour value.
        data_by_time[time_point] = {'data': subset, 'g_values': g.detach().cpu().numpy().ravel()}

    if not data_by_time:
        raise ValueError("plot_g_values: df contains no samples to plot")

    all_g_values = np.concatenate([content['g_values'] for content in data_by_time.values()])

    # Clip the top 1% so outliers do not dominate the colour scale.
    vmax_value = np.percentile(all_g_values, 99)
    norm = plt.Normalize(vmin=all_g_values.min(), vmax=vmax_value, clip=True)

    fig, ax = plt.subplots(figsize=(12, 8))
    for time_point, content in data_by_time.items():
        subset = content['data']
        colors = plt.cm.rainbow(norm(content['g_values']))
        ax.scatter(subset['x1'], subset['x2'], color=colors, label=f'Time {time_point}', alpha=0.7, marker='o')

    ax.set_xlabel('Gene $X_1$')
    ax.set_ylabel('Gene $X_2$')
    ax.set_title('Predicted growth rate')

    sm = plt.cm.ScalarMappable(cmap='rainbow', norm=norm)
    sm.set_array(all_g_values)
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label('Normalized predicted growth rate')

    # Tick positions stay in raw g units; relabel them with their normalized value.
    cbar.ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{norm(x):.2f}'))

    # The output_file argument was previously accepted but never used.
    if output_file:
        fig.savefig(output_file, bbox_inches='tight')
    plt.show()

plot_g_values(df, f_net, output_file='gene_growth_pre_post.pdf')
../_images/notebook_gaussian_20d_13_0.png
[16]:
    # Generate samples and trajectories from the trained model for the comparison plot.
    # NOTE(review): this cell is indented at top level. Plain Python rejects that;
    # it runs only because IPython strips a common leading indent. Consider dedenting.
    # Use the notebook-wide `device` rather than a hard-coded 'cuda' string.
    f_net.to(device)
    generated, trajectories = generate_plot_data(
        f_net, df, n_points=400, n_trajectories=40, n_bins=100,
        sample_with_replacement=True, use_cuda=use_cuda, samples_key='samples',
        logger=logger
    )

tensor([0., 1.], device='cuda:0')
<class 'torch.Tensor'>
tensor([0.0000, 0.0101, 0.0202, 0.0303, 0.0404, 0.0505, 0.0606, 0.0707, 0.0808,
        0.0909, 0.1010, 0.1111, 0.1212, 0.1313, 0.1414, 0.1515, 0.1616, 0.1717,
        0.1818, 0.1919, 0.2020, 0.2121, 0.2222, 0.2323, 0.2424, 0.2525, 0.2626,
        0.2727, 0.2828, 0.2929, 0.3030, 0.3131, 0.3232, 0.3333, 0.3434, 0.3535,
        0.3636, 0.3737, 0.3838, 0.3939, 0.4040, 0.4141, 0.4242, 0.4343, 0.4444,
        0.4545, 0.4646, 0.4747, 0.4848, 0.4949, 0.5051, 0.5152, 0.5253, 0.5354,
        0.5455, 0.5556, 0.5657, 0.5758, 0.5859, 0.5960, 0.6061, 0.6162, 0.6263,
        0.6364, 0.6465, 0.6566, 0.6667, 0.6768, 0.6869, 0.6970, 0.7071, 0.7172,
        0.7273, 0.7374, 0.7475, 0.7576, 0.7677, 0.7778, 0.7879, 0.7980, 0.8081,
        0.8182, 0.8283, 0.8384, 0.8485, 0.8586, 0.8687, 0.8788, 0.8889, 0.8990,
        0.9091, 0.9192, 0.9293, 0.9394, 0.9495, 0.9596, 0.9697, 0.9798, 0.9899,
        1.0000], device='cuda:0')
<class 'torch.Tensor'>
[18]:
    # Overlay the generated samples and trajectories on the observed data in the
    # (x1, x2) plane. With save=True the figure is presumably written to
    # exp_dir/comparision.png (file name spelled as the library expects).
    plot_comparision(
        df,
        generated,
        trajectories,
        x='x1',
        y='x2',
        is_3d=False,
        df_time_key='samples',
        palette='viridis',
        save=True,
        path=exp_dir,
        file='comparision.png',
    )
[18]:
../_images/notebook_gaussian_20d_15_0.png

Pretrain the score model

[19]:
import pandas as pd
import anndata as ad

# Reload the best pretrained model so this section does not depend on leftover kernel state.
f_net.load_state_dict(torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu')))
f_net.to(device)

print("DataFrame shape:", df.shape)
print("DataFrame columns:", df.columns)
n = dim
samples = df['samples'].values
column_names = [f'x{i}' for i in range(1, n + 1)]

obsm_data = df[column_names].values
print("obsm_data shape:", obsm_data.shape)

# Wrap the features in an AnnData object (empty .X, features in obsm['X_pca']) so
# scanpy plotting can be used. Cells get unique string names. Indexing obs by the
# time labels produced duplicate obs names and an implicit int->str index conversion.
adata = ad.AnnData(obs=pd.DataFrame(index=[str(i) for i in range(len(df))]))
adata.obsm['X_pca'] = obsm_data

adata_loaded = adata
print(adata_loaded)

import scanpy as sc

adata.obs['samples'] = samples
sc.pl.scatter(adata, basis="pca", color="samples")

# One feature matrix per time point, ordered by time label. Iterating over the
# actual labels, rather than range(n_times), does not assume the labels are
# 0..n_times-1. For integer labels 0..n-1 the result is identical.
time_values = np.sort(adata.obs["samples"].unique())
n_times = len(time_values)
print(n_times)
X = [
    adata.obsm["X_pca"][(adata.obs["samples"] == t).to_numpy()]
    for t in time_values
]
DataFrame shape: (1900, 21)
DataFrame columns: Index(['samples', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10',
       'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20'],
      dtype='object')
obsm_data shape: (1900, 20)
AnnData object with n_obs × n_vars = 1900 × 0
    obsm: 'X_pca'
/lustre/home/2301110060/software/miniconda3/envs/DeepRUOT/lib/python3.10/site-packages/anndata/_core/aligned_df.py:68: ImplicitModificationWarning: Transforming to str index.
  warnings.warn("Transforming to str index.", ImplicitModificationWarning)
/lustre/home/2301110060/software/miniconda3/envs/DeepRUOT/lib/python3.10/site-packages/anndata/_core/anndata.py:1756: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  utils.warn_names_duplicates("obs")
../_images/notebook_gaussian_20d_17_2.png
2
[20]:
from DeepRUOT.utils import OTPlanSampler, ConditionalFlowMatcher, ExactOptimalTransportConditionalFlowMatcher, TargetConditionalFlowMatcher,SchrodingerBridgeConditionalFlowMatcher, VariancePreservingConditionalFlowMatcher,generate_state_trajectory
from DeepRUOT.models import scoreNet2

# Set up Schrödinger-bridge conditional flow matching for score pretraining.
# The batch size equals the number of cells observed at time 0.
batch_size = len(df.loc[df['samples'] == 0, 'x1'])
sigma = 0.1
time = torch.Tensor(groups)  # assumes `groups` holds the observed time points (defined elsewhere)
SF2M = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)

sf2m_score_model = (
    scoreNet2(in_out_dim=dim, hidden_dim=128, activation='leakyrelu')
    .float()
    .to(device)
)
sf2m_optimizer = torch.optim.Adam(sf2m_score_model.parameters(), lr=1e-4)

# Presumably simulates f_net state trajectories between the observed times; see DeepRUOT.utils.
trajectory = generate_state_trajectory(X, n_times, batch_size, f_net, time, device)
[21]:
from DeepRUOT.utils import get_batch

# Pretrain the score network by regressing lambda_t * grad(score) onto -eps along
# the SF2M interpolants produced by get_batch.
N_SCORE_STEPS = 3001  # optimisation steps (step 0 .. 3000 inclusive)
LOG_EVERY = 100       # print diagnostics every this many steps

max_norm_ut = torch.tensor(0.0)  # only referenced by the disabled diagnostic below
lambda_penalty = 0  # weight of the penalty on positive score outputs (0 disables it)
for i in tqdm(range(N_SCORE_STEPS)):
    sf2m_optimizer.zero_grad()
    t, xt, ut, eps = get_batch(SF2M, X, trajectory, batch_size, n_times, return_noise=True)
    t = torch.unsqueeze(t, 1)
    lambda_t = SF2M.compute_lambda(t % 1)
    value_st = sf2m_score_model(t, xt)
    st = sf2m_score_model.compute_gradient(t, xt)
    positive_st = torch.relu(value_st)
    penalty = lambda_penalty * torch.max(positive_st)

    # max_norm_ut = torch.maximum(torch.max(torch.sum(ut**2, dim=1)), max_norm_ut)

    # Give lambda_t exactly one weight per sample, shape (batch, 1), before scaling st.
    # `t` was unsqueezed to (batch, 1) above. If compute_lambda keeps that shape, the
    # old `lambda_t[:, None]` was (batch, 1, 1) and broadcast against st into a
    # (batch, batch, dim) tensor, pairing each score with every other sample's weight.
    # For a (batch,) lambda_t this reshape is identical to `[:, None]`.
    score_loss = torch.mean((lambda_t.reshape(-1, 1) * st + eps) ** 2)
    if i % LOG_EVERY == 0:
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(str(torch.max(positive_st)))
        tqdm.write(f"{i}:  {score_loss.item():0.2f}")
    loss = score_loss + penalty

    loss.backward()
    sf2m_optimizer.step()
  0%|          | 6/3001 [00:00<01:51, 26.83it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
0:  1.01
  3%|▎         | 105/3001 [00:03<01:46, 27.22it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
100:  0.96
  7%|▋         | 204/3001 [00:07<01:42, 27.29it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
200:  0.89
 10%|█         | 306/3001 [00:11<01:38, 27.28it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
300:  0.85
 13%|█▎        | 405/3001 [00:14<01:35, 27.22it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
400:  0.84
 17%|█▋        | 504/3001 [00:18<01:31, 27.28it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
500:  0.84
 20%|██        | 606/3001 [00:22<01:27, 27.25it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
600:  0.81
 23%|██▎       | 705/3001 [00:25<01:24, 27.29it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
700:  0.84
 27%|██▋       | 804/3001 [00:29<01:20, 27.23it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
800:  0.81
 30%|███       | 906/3001 [00:33<01:16, 27.29it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
900:  0.82
 33%|███▎      | 1005/3001 [00:36<01:13, 27.29it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
1000:  0.80
 37%|███▋      | 1104/3001 [00:40<01:09, 27.29it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
1100:  0.79
 40%|████      | 1206/3001 [00:44<01:05, 27.31it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
1200:  0.79
 43%|████▎     | 1305/3001 [00:47<01:02, 27.28it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
1300:  0.80
 47%|████▋     | 1404/3001 [00:51<00:58, 27.31it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
1400:  0.80
 50%|█████     | 1506/3001 [00:55<00:54, 27.25it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
1500:  0.79
 53%|█████▎    | 1605/3001 [00:58<00:51, 27.31it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
1600:  0.80
 57%|█████▋    | 1704/3001 [01:02<00:47, 27.30it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
1700:  0.81
 60%|██████    | 1806/3001 [01:06<00:43, 27.30it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
1800:  0.79
 63%|██████▎   | 1905/3001 [01:09<00:40, 27.32it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
1900:  0.80
 67%|██████▋   | 2004/3001 [01:13<00:36, 27.18it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2000:  0.77
 70%|███████   | 2106/3001 [01:17<00:32, 27.31it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2100:  0.80
 73%|███████▎  | 2205/3001 [01:20<00:29, 27.21it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2200:  0.79
 77%|███████▋  | 2304/3001 [01:24<00:25, 27.30it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2300:  0.79
 80%|████████  | 2406/3001 [01:28<00:21, 27.34it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2400:  0.79
 83%|████████▎ | 2505/3001 [01:31<00:18, 27.31it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2500:  0.78
 87%|████████▋ | 2604/3001 [01:35<00:14, 27.32it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2600:  0.79
 90%|█████████ | 2706/3001 [01:39<00:10, 27.29it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2700:  0.78
 93%|█████████▎| 2805/3001 [01:42<00:07, 27.29it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2800:  0.80
 97%|█████████▋| 2904/3001 [01:46<00:03, 27.16it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
2900:  0.78
100%|██████████| 3001/3001 [01:49<00:00, 27.30it/s]
tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)
3000:  0.78

[22]:
# Persist the pretrained score network weights next to the f_net checkpoint.
score_model_path = os.path.join(exp_dir, 'score_model')
torch.save(sf2m_score_model.state_dict(), score_model_path)
[23]:
import numpy as np
import matplotlib.pyplot as plt
import torch



# Visualise the score model's log-density output and its gradient w.r.t. x at a
# fixed time on a 2-D slice of the dim-dimensional space: x1/x2 vary over a grid,
# the remaining dim-2 coordinates are held at zero.
GRID_RES = 100
x_range = np.linspace(-1.5, 1.5, GRID_RES)
y_range = np.linspace(-0.5, 3, GRID_RES)
xv, yv = np.meshgrid(x_range, y_range)
grid_points = np.stack([xv, yv], axis=-1).reshape(-1, 2)
num_grid_points = grid_points.shape[0]

grid_points_tensor = torch.tensor(grid_points).float().to(device)
padding_dim = dim - 2
# Row count follows the grid (previously a hard-coded 10000) so GRID_RES can change.
padding_tensor = torch.zeros(num_grid_points, padding_dim, device=device)
expanded_tensor = torch.cat((grid_points_tensor, padding_tensor), dim=1)

t_value = 1.0
t_tensor = torch.full((num_grid_points, 1), t_value, device=device)
print(t_tensor.shape)

expanded_tensor.requires_grad_(True)
log_density_values = sf2m_score_model(t_tensor, expanded_tensor)
density_values = torch.exp(log_density_values)

# autograd.grad returns d(sum log p)/dx directly and, unlike .backward(), does not
# accumulate gradients into the score model's parameters (which would otherwise
# pile up each time this cell is re-run).
(gradients,) = torch.autograd.grad(
    log_density_values,
    expanded_tensor,
    grad_outputs=torch.ones_like(log_density_values),
)
print(gradients.shape)
gradients_np = gradients.detach().cpu().numpy().reshape(GRID_RES, GRID_RES, dim)

step = 5  # subsample the grid so the quiver arrows stay readable
xv_quiver = xv[::step, ::step]
yv_quiver = yv[::step, ::step]
gradients_np_quiver = gradients_np[::step, ::step, :]

neg_log_density_grid = -log_density_values.detach().cpu().numpy().reshape(GRID_RES, GRID_RES)

plt.figure(figsize=(8, 6))
plt.contourf(xv, yv, neg_log_density_grid, levels=50, cmap='rainbow')
plt.colorbar(label='-log density')
plt.quiver(xv_quiver, yv_quiver, gradients_np_quiver[:, :, 0], gradients_np_quiver[:, :, 1], color='white', scale=5)
plt.xlabel('x1')
plt.ylabel('x2')
plt.title(f'Learned log-density gradient at t = {t_value}')
plt.show()
torch.Size([10000, 1])
torch.Size([10000, 20])
../_images/notebook_gaussian_20d_21_1.png

Train

[12]:
# Restore the pretrained score model and flow network. Checkpoints are loaded
# onto CPU first, then moved to `device`; the growth sub-network is made trainable.
sf2m_score_model.load_state_dict(
    torch.load(os.path.join(exp_dir, 'score_model'), map_location=torch.device('cpu'))
)
sf2m_score_model.to(device)
# Proper two-argument join (was a single pre-concatenated argument, i.e. a no-op join).
f_net.load_state_dict(
    torch.load(os.path.join(exp_dir, 'best_model'), map_location=torch.device('cpu'))
)
f_net.to(device)
# NOTE(review): only g_net is (re-)enabled here; presumably other sub-networks'
# requires_grad flags are managed elsewhere — confirm.
for param in f_net.g_net.parameters():
    param.requires_grad = True
[12]:
FNet(
  (v_net): velocityNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=21, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
      (1-2): 2 x Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
    )
    (out): Linear(in_features=128, out_features=20, bias=True)
  )
  (g_net): growthNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=21, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (s_net): scoreNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=21, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (d_net): indediffusionNet(
    (activation): LeakyReLU(negative_slope=0.01)
    (net): Sequential(
      (0): Linear(in_features=1, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
[13]:
import numpy as np
from DeepRUOT.utils import density1
import matplotlib.pyplot as plt
import torch

# Evaluate DeepRUOT.utils.density1 for the t=0 samples on a 2-D slice of the
# dim-dimensional space (x1/x2 vary over a grid, all other coordinates are zero).
samples_t0 = df[df['samples'] == 0]
# Only the x1/x2 coordinates are filled in; the remaining dim-2 columns stay zero.
# NOTE(review): datatime0 is also passed to train_all later — confirm the zero
# padding is intended there.
datatime0 = torch.zeros(samples_t0.shape[0], dim)
datatime0[:, 0] = torch.tensor(samples_t0.x1.values, dtype=torch.float32)
datatime0[:, 1] = torch.tensor(samples_t0.x2.values, dtype=torch.float32)
# Rebinds the notebook-global `device` used by subsequent cells.
device = 'cpu'

GRID_RES = 100
x_range = np.linspace(-1.5, 1.5, GRID_RES)
y_range = np.linspace(-0.5, 3, GRID_RES)
xv, yv = np.meshgrid(x_range, y_range)
num_grid_points = xv.size  # GRID_RES * GRID_RES

# Grid with `dim` coordinates: the first two vary, the others are zero.
grid_points = np.zeros((num_grid_points, dim), dtype=np.float32)
grid_points[:, 0] = xv.flatten()
grid_points[:, 1] = yv.flatten()

# Convert grid points to a PyTorch tensor on the selected device (built once).
grid_points_tensor = torch.tensor(grid_points).float().to(device)

with torch.no_grad():
    function_values = (
        density1(grid_points_tensor, datatime0, device)
        .cpu()
        .detach()
        .numpy()
        .reshape(GRID_RES, GRID_RES)
    )

plt.figure(figsize=(8, 6))
plt.contourf(xv, yv, function_values, levels=100, cmap='rainbow')
plt.colorbar(label='Density value')
plt.xlabel('x1')
plt.ylabel('x2')
# Title derived from the actual grid bounds.
plt.title(f'Density Plot on [{x_range[0]:g}, {x_range[-1]:g}] x [{y_range[0]:g}, {y_range[-1]:g}]')
plt.show()
../_images/notebook_gaussian_20d_24_0.png
[14]:
from DeepRUOT.train import train_all
# Rebinds the notebook-global device; the training call below uses it.
device = 'cpu'
# Fine-tune the flow network and the score model jointly with plain SGD.
trainable_params = [*f_net.parameters(), *sf2m_score_model.parameters()]
optimizer = torch.optim.SGD(trainable_params, 1e-5)

[ ]:
    # Joint fine-tuning pass: local losses only, with PINN and penalty terms enabled.
    if n_local_epochs > 0:
        logger.info(f'Beginning Training')
        for _ in tqdm(range(1), desc='Training Epoch'):
            l_loss, b_loss, g_loss = train_all(
                f_net, df, groups, optimizer, 10,
                criterion=criterion,
                local_loss=True,
                global_loss=False,
                apply_losses_in_time=apply_losses_in_time,
                use_pinn=True,
                lambda_pinn=1,
                use_penalty=True,
                use_density_loss=False,
                lambda_density=10,
                hinge_value=hinge_value,
                lambda_initial=0.1,
                hold_one_out=hold_one_out,
                hold_out=hold_out,
                top_k=top_k,
                sample_size=sample_size,
                sample_with_replacement=sample_with_replacement,
                relative_mass=relative_mass,
                initial_size=initial_size,
                datatime0=datatime0,
                sf2m_score_model=sf2m_score_model,
                use_cuda=use_cuda,
                device=device,
                sigmaa=sigma,
                logger=logger,
            )
            # Append this pass's losses to the running histories.
            for loss_name, loss_values in l_loss.items():
                local_losses[loss_name].extend(loss_values)
            batch_losses.extend(b_loss)
            globe_losses.extend(g_loss)

[17]:
# Save the jointly fine-tuned score and flow models under exp_dir.
for module, fname in ((sf2m_score_model, 'score_model_result'), (f_net, 'model_result')):
    torch.save(module.state_dict(), os.path.join(exp_dir, fname))
[24]:
# Initial condition for the SDE: the t=0 rows, columns 1..dim.
# NOTE(review): assumes column 0 of df is the time label and columns 1..dim are
# the features x1..x{dim} — confirm against the dataframe schema.
rows_t0 = df.loc[df['samples'] == 0]
data = torch.tensor(rows_t0.values, dtype=torch.float32).requires_grad_()
data_t0 = data[:, 1:dim + 1].to(device).requires_grad_()
print(data_t0.shape)
x0 = data_t0.to(device)
torch.Size([500, 20])
[25]:
import torchsde
class SDE(torch.nn.Module):
    """Ito SDE with diagonal noise, in the form expected by ``torchsde.sdeint``.

    Drift ``f`` is ``ode_drift(t, y)`` plus ``score.compute_gradient(t, y)``;
    diffusion ``g`` is the constant ``sigma`` on every coordinate.

    Args:
        ode_drift: callable ``(t, y) -> tensor`` shaped like ``y``.
        score: object exposing ``compute_gradient(t_batch, y)`` where
            ``t_batch`` has shape ``(batch, 1)``.
        input_size: stored for interface compatibility; not used here.
        sigma: diffusion scale.
    """

    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, ode_drift, score, input_size=(3, 32, 32), sigma=1.0):
        super().__init__()
        self.drift = ode_drift
        self.score = score
        self.input_size = input_size
        self.sigma = sigma

    def f(self, t, y):
        """Drift term: velocity field plus score gradient at time ``t``."""
        velocity = self.drift(t, y)
        # The score model takes one time value per sample.
        t_batch = t.expand(y.shape[0], 1)
        # NOTE(review): no sigma**2/2 factor is applied to the score term here;
        # presumably compute_gradient accounts for it — confirm.
        return velocity + self.score.compute_gradient(t_batch, y)

    def g(self, t, y):
        """Diagonal diffusion: constant ``self.sigma`` on every coordinate."""
        # Use the instance's sigma (previously read a notebook-global `sigma`,
        # silently ignoring the constructor argument).
        return torch.ones_like(y) * self.sigma


sde = SDE(f_net.v_net, sf2m_score_model, input_size=(dim,), sigma=sigma)
# Dense time grid (400 points over [0, n_times - 1]) for smooth trajectory plots.
dense_ts = torch.linspace(0, n_times - 1, 400, device=device)
sde_traj = torchsde.sdeint(sde, x0.to(device), ts=dense_ts, dt=0.01).cpu()
[26]:
# Keep a random subset of simulated trajectories for plotting, then save them.
sample_number = 10
sample_indices = random.sample(range(sde_traj.size(1)), sample_number)
sampled_sde_trajec = sde_traj[:, sample_indices, :]

# Nested Python lists -> object ndarray (reloaded later with allow_pickle=True).
sampled_sde_trajec = np.array(sampled_sde_trajec.tolist(), dtype=object)
np.save(exp_dir + '/sde_trajec_our_post_plot.npy', sampled_sde_trajec)

[27]:
# Integrate the SDE only at the time points held in the tensor `time` (defined in
# an earlier cell; note it shadows the stdlib `time` module name).
ts_points = time.to(device)
sde_point = torchsde.sdeint(sde, x0.to(device), ts=ts_points, dt=0.01).cpu()
[28]:
# Persist the simulated points as an object ndarray of nested Python floats.
sde_point_np = sde_point.detach().numpy()
sde_point_list = sde_point_np.tolist()
sde_point_array = np.array(sde_point_list, dtype=object)
points_path = exp_dir + '/sde_point_our_post.npy'
np.save(points_path, sde_point_array)


[42]:
# allow_pickle is required because the file stores an object array; it was
# written by this notebook, so only load it from a trusted exp_dir.
sde_point_our_post = np.load(file=exp_dir + '/sde_point_our_post.npy', allow_pickle=True)

[30]:
from DeepRUOT.plots import new_plot_comparisions2
[43]:
# Reload the saved trajectory subset and overlay it on the data (x1/x2 view).
trajec_path = exp_dir + '/sde_trajec_our_post_plot.npy'
sde_trajec_our_post_plot = np.load(trajec_path, allow_pickle=True)
new_plot_comparisions2(
    df,
    sde_point_our_post,
    sde_trajec_our_post_plot,
    palette='viridis',
    df_time_key='samples',
    save=True,
    path=exp_dir,
    file='sde_trajec_our_post_plot.pdf',
    x='x1',
    y='x2',
    z='x3',
    is_3d=False,
)
[43]:
../_images/notebook_gaussian_20d_35_0.png