methane_super_emitters.optimize
import click
import lightning as L
import optuna
from methane_super_emitters.model import SuperEmitterDetector
from methane_super_emitters.datamodule import TROPOMISuperEmitterDataModule


def optimize_model(input_dir, max_epochs, n_trials):
    """Tune SuperEmitterDetector hyperparameters with an Optuna study.

    Each trial samples a dropout rate, a weight decay and a learning rate. It
    trains a fresh ``SuperEmitterDetector`` on a
    ``TROPOMISuperEmitterDataModule`` built from ``input_dir`` and scores the
    trial by the ``val_acc`` entry of ``trainer.callback_metrics`` after
    fitting. The study maximizes that score.

    Args:
        input_dir: Data directory handed to ``TROPOMISuperEmitterDataModule``.
        max_epochs: ``max_epochs`` for every Lightning training run.
        n_trials: Number of Optuna trials to run.

    Returns:
        The finished ``optuna`` study. As side effects, all trials are written
        to ``opt_results.csv`` (relative to the current working directory) and
        the best parameters are printed.
    """
    # Training runs per trial; the trial score is their mean val_acc.
    n_repeats = 1

    def objective(trial):
        fields = ["methane", "u10", "v10", "qa"]
        dropout_rate = trial.suggest_float("dropout", 0.1, 0.9)
        # Both ranges span three orders of magnitude, so sample them
        # log-uniformly. A uniform draw would put ~90% of trials above 1e-3
        # and almost never explore the small end of the range.
        weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)
        learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
        total = 0.0
        for _ in range(n_repeats):
            model = SuperEmitterDetector(
                fields=fields,
                dropout=dropout_rate,
                weight_decay=weight_decay,
                lr=learning_rate,
            )
            datamodule = TROPOMISuperEmitterDataModule(input_dir, fields=fields)
            trainer = L.Trainer(max_epochs=max_epochs)
            trainer.fit(model=model, datamodule=datamodule)
            total += trainer.callback_metrics["val_acc"].item()
        return total / n_repeats

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)
    study.trials_dataframe().to_csv("opt_results.csv")
    print("Best parameters:", study.best_params)
    return study


@click.command()
# required=True: otherwise input_dir silently becomes None and, presumably,
# only fails once the data module tries to read it inside the first trial.
@click.option("-i", "--input-dir", required=True, help="Data directory")
@click.option(
    "-m", "--max-epochs", type=int, default=100, help="Maximum number of epochs"
)
@click.option(
    "-n",
    "--n-trials",
    type=int,
    default=200,
    help="Number of trials or points to sample",
)
def optimize_model_(input_dir, max_epochs, n_trials):
    """Search SuperEmitterDetector hyperparameters with Optuna."""
    return optimize_model(input_dir, max_epochs, n_trials)


if __name__ == "__main__":
    optimize_model_()
def
optimize_model(input_dir, max_epochs, n_trials):
def optimize_model(input_dir, max_epochs, n_trials):
    """Tune SuperEmitterDetector hyperparameters with an Optuna study.

    Each trial samples a dropout rate, a weight decay and a learning rate. It
    trains a fresh ``SuperEmitterDetector`` on a
    ``TROPOMISuperEmitterDataModule`` built from ``input_dir`` and scores the
    trial by the ``val_acc`` entry of ``trainer.callback_metrics`` after
    fitting. The study maximizes that score.

    Args:
        input_dir: Data directory handed to ``TROPOMISuperEmitterDataModule``.
        max_epochs: ``max_epochs`` for every Lightning training run.
        n_trials: Number of Optuna trials to run.

    Returns:
        The finished ``optuna`` study. As side effects, all trials are written
        to ``opt_results.csv`` (relative to the current working directory) and
        the best parameters are printed.
    """
    # Training runs per trial; the trial score is their mean val_acc.
    n_repeats = 1

    def objective(trial):
        fields = ["methane", "u10", "v10", "qa"]
        dropout_rate = trial.suggest_float("dropout", 0.1, 0.9)
        # Both ranges span three orders of magnitude, so sample them
        # log-uniformly. A uniform draw would put ~90% of trials above 1e-3
        # and almost never explore the small end of the range.
        weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)
        learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
        total = 0.0
        for _ in range(n_repeats):
            model = SuperEmitterDetector(
                fields=fields,
                dropout=dropout_rate,
                weight_decay=weight_decay,
                lr=learning_rate,
            )
            datamodule = TROPOMISuperEmitterDataModule(input_dir, fields=fields)
            trainer = L.Trainer(max_epochs=max_epochs)
            trainer.fit(model=model, datamodule=datamodule)
            total += trainer.callback_metrics["val_acc"].item()
        return total / n_repeats

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)
    study.trials_dataframe().to_csv("opt_results.csv")
    print("Best parameters:", study.best_params)
    return study