methane_super_emitters.predict

 1import click
 2import torch
 3import lightning as L
 4import numpy as np
 5import glob
 6import os
 7import netCDF4
 8import uuid
 9from joblib import Parallel, delayed
10import datetime
11from methane_super_emitters.model import SuperEmitterDetector
12from methane_super_emitters.dataset import TROPOMISuperEmitterDataset
13from methane_super_emitters.utils import destripe, parse_date, patch_generator
14
15
def predict_from_tropomi_file(file_path, output_path, model, dataset, threshold):
    """Scan one TROPOMI file and save every patch the model flags.

    Each patch yielded by ``patch_generator(file_path)`` is scored with
    ``predict``. Every patch whose score is strictly greater than
    ``threshold`` is written to ``output_path`` as ``<uuid4>.npz`` via
    ``np.savez(path, **patch)``.

    An ``OSError`` raised while iterating the file or while saving a patch
    stops processing of *this* file only. It is reported rather than
    propagated, so one bad file does not abort a batch run. Patches saved
    before the error are kept.

    Args:
        file_path: Path of the TROPOMI file to scan.
        output_path: Existing directory that receives the ``.npz`` hits.
        model: Model passed through to ``predict``.
        dataset: Object passed through to ``predict`` (used for normalization).
        threshold: Score above which a patch counts as a hit.
    """
    print(f"ANALYZING {file_path}")
    try:
        for patch in patch_generator(file_path):
            if predict(model, dataset, patch) > threshold:
                print(f"Found emitter in {file_path}")
                out_file = os.path.join(output_path, f"{uuid.uuid4()}.npz")
                np.savez(out_file, **patch)
    except OSError as err:
        # Best effort: skip this file, but leave a trace. Previously this was
        # `pass`, which also hid save failures (e.g. a missing output
        # directory raises FileNotFoundError, an OSError subclass).
        print(f"Skipping {file_path}: {err}")
25
26
def predict(model, dataset, patch):
    """Return the model's score for a single patch.

    The patch is normalized with ``dataset.normalize`` and wrapped in a batch
    of size one. The batch is run through ``model`` without gradient
    tracking, and ``y_hat[0][0]`` of the output is returned as a NumPy
    scalar.

    Args:
        model: A ``torch.nn.Module`` (e.g. the loaded detector). This
            function does not change its train/eval mode; callers should
            call ``model.eval()`` beforehand.
        dataset: Object providing ``normalize(patch)``.
        patch: A single patch (in this module, one yielded by
            ``patch_generator``).

    Returns:
        The first element of the first output row, as a NumPy scalar.
    """
    x = np.array([dataset.normalize(patch)])
    # Build the input on the model's device: a Lightning checkpoint may be
    # restored onto a GPU. Plain modules have no ``device`` attribute and get
    # the default device, exactly as before.
    device = getattr(model, "device", None)
    # Inference only: skip autograd bookkeeping (saves memory and time).
    with torch.no_grad():
        y_hat = model(torch.tensor(x, dtype=torch.float, device=device))
    return y_hat.detach().cpu().numpy()[0][0]
32
33
@click.command()
@click.option("-c", "--checkpoint", required=True, help="Checkpoint file")
@click.option("-d", "--dataset", help="Directory with the dataset")
@click.option("-i", "--input-dir", required=True, help="Directory with files")
@click.option("-o", "--output-dir", required=True, help="Output directory")
@click.option("-n", "--n-jobs", help="Number of parallel jobs", default=1)
@click.option(
    "-t",
    "--threshold",
    help="Threshold for the value of the sigmoid output to qualify as a hit",
    default=0.9,
)
def main(checkpoint, dataset, input_dir, output_dir, n_jobs, threshold):
    """Run the detector over every ``<input_dir>/*/*/*.nc`` file.

    Hits are written to ``output_dir`` as ``.npz`` files by
    ``predict_from_tropomi_file``. Files are processed in parallel with
    ``joblib`` using ``n_jobs`` workers.
    """
    model = SuperEmitterDetector.load_from_checkpoint(checkpoint)
    # A freshly constructed nn.Module is in training mode, and
    # load_from_checkpoint does not change that. Switch to eval so dropout /
    # batch-norm behave as they should at inference time.
    model.eval()
    dataset = TROPOMISuperEmitterDataset(dataset)
    dataset.unload()
    # np.savez does not create directories; make sure the target exists so
    # hits are not lost to FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)
    # Layout is <input_dir>/<month>/<day>/*.nc. A single sorted glob covers
    # the same files as walking month/day directories one by one, and allows
    # one Parallel dispatch with a reproducible order.
    file_paths = sorted(glob.glob(os.path.join(input_dir, "*", "*", "*.nc")))
    Parallel(n_jobs=n_jobs)(
        delayed(predict_from_tropomi_file)(
            file_path, output_dir, model, dataset, threshold
        )
        for file_path in file_paths
    )


if __name__ == "__main__":
    main()
def predict_from_tropomi_file(file_path, output_path, model, dataset, threshold):
    """Scan one TROPOMI file and save every patch the model flags.

    Each patch yielded by ``patch_generator(file_path)`` is scored with
    ``predict``. Every patch whose score is strictly greater than
    ``threshold`` is written to ``output_path`` as ``<uuid4>.npz`` via
    ``np.savez(path, **patch)``.

    An ``OSError`` raised while iterating the file or while saving a patch
    stops processing of *this* file only. It is reported rather than
    propagated, so one bad file does not abort a batch run. Patches saved
    before the error are kept.

    Args:
        file_path: Path of the TROPOMI file to scan.
        output_path: Existing directory that receives the ``.npz`` hits.
        model: Model passed through to ``predict``.
        dataset: Object passed through to ``predict`` (used for normalization).
        threshold: Score above which a patch counts as a hit.
    """
    print(f"ANALYZING {file_path}")
    try:
        for patch in patch_generator(file_path):
            if predict(model, dataset, patch) > threshold:
                print(f"Found emitter in {file_path}")
                out_file = os.path.join(output_path, f"{uuid.uuid4()}.npz")
                np.savez(out_file, **patch)
    except OSError as err:
        # Best effort: skip this file, but leave a trace. Previously this was
        # `pass`, which also hid save failures (e.g. a missing output
        # directory raises FileNotFoundError, an OSError subclass).
        print(f"Skipping {file_path}: {err}")
def predict(model, dataset, patch):
    """Return the model's score for a single patch.

    The patch is normalized with ``dataset.normalize`` and wrapped in a batch
    of size one. The batch is run through ``model`` without gradient
    tracking, and ``y_hat[0][0]`` of the output is returned as a NumPy
    scalar.

    Args:
        model: A ``torch.nn.Module`` (e.g. the loaded detector). This
            function does not change its train/eval mode; callers should
            call ``model.eval()`` beforehand.
        dataset: Object providing ``normalize(patch)``.
        patch: A single patch (in this module, one yielded by
            ``patch_generator``).

    Returns:
        The first element of the first output row, as a NumPy scalar.
    """
    x = np.array([dataset.normalize(patch)])
    # Build the input on the model's device: a Lightning checkpoint may be
    # restored onto a GPU. Plain modules have no ``device`` attribute and get
    # the default device, exactly as before.
    device = getattr(model, "device", None)
    # Inference only: skip autograd bookkeeping (saves memory and time).
    with torch.no_grad():
        y_hat = model(torch.tensor(x, dtype=torch.float, device=device))
    return y_hat.detach().cpu().numpy()[0][0]