methane_super_emitters.predict
"""Scan TROPOMI netCDF files for methane super-emitter candidates.

Command-line entry point: loads a trained ``SuperEmitterDetector``
checkpoint, walks the ``*.nc`` files found two directory levels below the
input directory (called "month" and "day" in the code) and stores every
patch whose model score exceeds the threshold as an ``.npz`` file.
"""
import datetime
import glob
import os
import uuid

import click
import torch
import lightning as L
import numpy as np
import netCDF4
from joblib import Parallel, delayed

from methane_super_emitters.model import SuperEmitterDetector
from methane_super_emitters.dataset import TROPOMISuperEmitterDataset
from methane_super_emitters.utils import destripe, parse_date, patch_generator


def predict_from_tropomi_file(file_path, output_path, model, dataset, threshold):
    """Score every patch of one file and save the ones above ``threshold``.

    Args:
        file_path: Path of the file handed to ``patch_generator``.
        output_path: Existing directory that receives one ``<uuid4>.npz``
            file per hit.
        model: Callable model passed through to :func:`predict`.
        dataset: Object providing ``normalize(patch)``; passed through to
            :func:`predict`.
        threshold: A patch is saved when its score is strictly greater
            than this value.

    An ``OSError`` raised while reading the file or writing a hit stops the
    processing of this file only; it is reported and not re-raised, so one
    bad file does not abort a batch run.
    """
    print(f"ANALYZING {file_path}")
    try:
        for patch in patch_generator(file_path):
            if predict(model, dataset, patch) > threshold:
                print(f"Found emitter in {file_path}")
                # The patch is unpacked into keyword arrays, so it is
                # expected to be a mapping of name -> array-like.
                np.savez(os.path.join(output_path, f"{uuid.uuid4()}.npz"), **patch)
    except OSError as err:
        # Stay best-effort, but never lose a failure silently: an unreadable
        # input or an unwritable output directory must be visible in the log.
        print(f"Skipping {file_path}: {err}")


def predict(model, dataset, patch):
    """Return the model score for a single patch.

    Args:
        model: Callable taking a float tensor with a leading batch axis.
        dataset: Object whose ``normalize(patch)`` yields the model input
            for one patch.
        patch: One patch as produced by ``patch_generator``.

    Returns:
        Element ``[0][0]`` of the model output, converted to NumPy.
    """
    # Wrap the normalized patch in a list to add a batch axis of size one.
    batch = np.array([dataset.normalize(patch)])
    # LightningModules expose ``.device``; a checkpoint may restore the model
    # on an accelerator, so build the input there. Plain callables without
    # the attribute keep the default (CPU) placement.
    device = getattr(model, "device", None)
    inputs = torch.tensor(batch, dtype=torch.float, device=device)
    # Pure inference: do not record an autograd graph.
    with torch.no_grad():
        y_hat = model(inputs)
    y_hat = y_hat.detach().cpu().numpy()
    return y_hat[0][0]


@click.command()
@click.option("-c", "--checkpoint", required=True, help="Checkpoint file")
@click.option("-d", "--dataset", required=True, help="Directory with the dataset")
@click.option("-i", "--input-dir", required=True, help="Directory with files")
@click.option("-o", "--output-dir", required=True, help="Output directory")
@click.option("-n", "--n-jobs", help="Number of parallel jobs", default=1)
@click.option(
    "-t",
    "--threshold",
    help="Threshold for the value of the sigmoid output to qualify as a hit",
    default=0.9,
)
def main(checkpoint, dataset, input_dir, output_dir, n_jobs, threshold):
    """Scan the *.nc files under INPUT_DIR and save patches scored above the threshold."""
    model = SuperEmitterDetector.load_from_checkpoint(checkpoint)
    # A freshly loaded module is not guaranteed to be in inference mode;
    # switch off dropout / batch-norm updates explicitly.
    model.eval()
    # Keep the CLI argument (a directory) distinct from the dataset object.
    emitter_dataset = TROPOMISuperEmitterDataset(dataset)
    emitter_dataset.unload()
    # Without the directory every np.savez call would fail with an OSError.
    os.makedirs(output_dir, exist_ok=True)
    for month_path in glob.glob(os.path.join(input_dir, "*")):
        for day_path in glob.glob(os.path.join(month_path, "*")):
            Parallel(n_jobs=n_jobs)(
                delayed(predict_from_tropomi_file)(
                    file_path, output_dir, model, emitter_dataset, threshold
                )
                for file_path in glob.glob(os.path.join(day_path, "*.nc"))
            )


if __name__ == "__main__":
    main()
def
predict_from_tropomi_file(file_path, output_path, model, dataset, threshold):
def predict_from_tropomi_file(file_path, output_path, model, dataset, threshold):
    """Score every patch of one file and save the ones above ``threshold``.

    Args:
        file_path: Path of the file handed to ``patch_generator``.
        output_path: Existing directory that receives one ``<uuid4>.npz``
            file per hit.
        model: Callable model passed through to ``predict``.
        dataset: Object passed through to ``predict``.
        threshold: A patch is saved when its score is strictly greater
            than this value.

    An ``OSError`` raised while reading the file or writing a hit stops the
    processing of this file only; it is reported and not re-raised, so one
    bad file does not abort a batch run.
    """
    print(f"ANALYZING {file_path}")
    try:
        for patch in patch_generator(file_path):
            if predict(model, dataset, patch) > threshold:
                print(f"Found emitter in {file_path}")
                # The patch is unpacked into keyword arrays, so it is
                # expected to be a mapping of name -> array-like.
                np.savez(os.path.join(output_path, f"{uuid.uuid4()}.npz"), **patch)
    except OSError as err:
        # Stay best-effort, but never lose a failure silently: an unreadable
        # input or an unwritable output directory must be visible in the log.
        print(f"Skipping {file_path}: {err}")
def
predict(model, dataset, patch):