methane_super_emitters.cam

 1import torch
 2import numpy as np
 3import matplotlib.pyplot as plt
 4import cv2
 5from pytorch_grad_cam import GradCAM
 6from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
 7from methane_super_emitters.dataset_stats import normalize
 8import click
 9
@click.command()
@click.option('-i', '--input-file', help='Input file')
@click.option('-o', '--output-file', help='Output file')
@click.option('-c', '--checkpoint', help='Model checkpoint file')
def main(input_file, output_file, checkpoint):
    """Render a Grad-CAM heatmap for a single input sample and save it as an image.

    The model is restored from ``checkpoint``, Grad-CAM is computed against the
    last entry of ``model.conv_layers``, and the resulting activation map is
    overlaid on the channel-wise mean of the normalized input.

    Args:
        input_file: Path passed to ``np.load``; the loaded object is handed to
            ``normalize`` together with the field names
            ``['methane', 'u10', 'v10', 'qa']``.
        output_file: Path of the image to write. When omitted, the figure is
            written to ``'test.png'`` (the previous hard-coded location).
        checkpoint: Path of the model checkpoint to load.
    """
    # NOTE(review): SuperEmitterDetector is not imported in the visible part of
    # this file, so this line raises NameError unless the import exists
    # elsewhere -- add the import from the module that defines the model.
    model = SuperEmitterDetector.load_from_checkpoint(checkpoint)
    model.eval()
    target_layer = model.conv_layers[-1]
    cam = GradCAM(model=model, target_layers=[target_layer])

    data = np.load(input_file)
    example_image = normalize(data, ['methane', 'u10', 'v10', 'qa'])
    # Wrap the sample in a leading batch axis of size one.
    example_image = torch.tensor(np.array([example_image]), dtype=torch.float)
    # requires_grad_ works in place: input_tensor and example_image are the
    # same tensor object from here on.
    input_tensor = example_image.requires_grad_(True)

    grayscale_cam = cam(input_tensor=input_tensor)[0]
    grayscale_cam = np.maximum(grayscale_cam, 0)
    # Scale to [0, 1]; skip the division for an all-zero map, which would
    # otherwise turn every pixel into NaN.
    peak = grayscale_cam.max()
    if peak > 0:
        grayscale_cam = grayscale_cam / peak

    # detach() is required: the tensor tracks gradients (see above) and
    # Tensor.numpy() refuses to convert such tensors.
    input_image_np = example_image[0].detach().mean(0).cpu().numpy()
    input_image_np = cv2.normalize(input_image_np, None, 0, 1, cv2.NORM_MINMAX)
    if input_image_np.ndim == 2:
        # show_cam_on_image blends the image with a three-channel colour
        # heatmap, so a single-channel background is replicated to (H, W, 3).
        input_image_np = np.repeat(input_image_np[..., np.newaxis], 3, axis=-1)
    visualization = show_cam_on_image(input_image_np, grayscale_cam, use_rgb=True)

    fig = plt.figure(figsize=(6, 6))
    plt.imshow(visualization)
    plt.axis("off")
    plt.title("Grad-CAM Heatmap")
    # Honour --output-file; fall back to the historical default when it is not given.
    plt.savefig(output_file or 'test.png')
    # Release the figure so repeated invocations do not accumulate open figures.
    plt.close(fig)
34
# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()