1import torch
2import numpy as np
3import matplotlib.pyplot as plt
4import cv2
5from pytorch_grad_cam import GradCAM
6from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
7from methane_super_emitters.dataset_stats import normalize
8import click
9
@click.command()
@click.option('-i', '--input-file', required=True, help='Input file')
@click.option('-o', '--output-file', default='test.png', show_default=True,
              help='Output file')
@click.option('-c', '--checkpoint', required=True, help='Model checkpoint file')
def main(input_file, output_file, checkpoint):
    """Render a Grad-CAM heatmap for one input sample and save it as an image.

    Loads a ``SuperEmitterDetector`` from ``checkpoint``, runs Grad-CAM on the
    last layer of ``model.conv_layers``, overlays the heatmap on a
    channel-averaged view of the normalized input and writes the figure.

    Args:
        input_file: Path passed to ``np.load``; the loaded object is normalized
            with ``normalize(data, ['methane', 'u10', 'v10', 'qa'])``
            (presumably an ``.npz`` with those keys -- confirm).
        output_file: Path of the image to write (defaults to ``test.png``,
            which was previously hard-coded).
        checkpoint: Model checkpoint file to load.
    """
    # NOTE(review): SuperEmitterDetector was referenced but never imported in
    # this file, which made the script fail with NameError. The module path
    # below is assumed from the package layout -- confirm it.
    from methane_super_emitters.model import SuperEmitterDetector

    model = SuperEmitterDetector.load_from_checkpoint(checkpoint)
    model.eval()
    # Grad-CAM is computed on the last convolutional stage of the model.
    target_layer = model.conv_layers[-1]
    cam = GradCAM(model=model, target_layers=[target_layer])

    data = np.load(input_file)
    example_image = normalize(data, ['methane', 'u10', 'v10', 'qa'])
    # Wrap in a list to add a leading batch dimension of size 1.
    example_image = torch.tensor(np.array([example_image]), dtype=torch.float)
    input_tensor = example_image.requires_grad_(True)

    # First (and only) CAM in the batch.
    grayscale_cam = cam(input_tensor=input_tensor)[0]
    grayscale_cam = np.maximum(grayscale_cam, 0)
    cam_max = grayscale_cam.max()
    # Guard against an all-zero CAM, which would otherwise divide by zero
    # and fill the heatmap with NaN.
    if cam_max > 0:
        grayscale_cam = grayscale_cam / cam_max

    # Average over the channel axis to get one plane for display. detach() is
    # required: the tensor requires grad, and .numpy() refuses such tensors.
    input_image_np = example_image[0].mean(0).detach().cpu().numpy()
    input_image_np = cv2.normalize(input_image_np, None, 0, 1, cv2.NORM_MINMAX)
    # show_cam_on_image expects a float32 (H, W, 3) image in [0, 1]; stack the
    # grayscale plane into 3 channels and clip away min-max rounding overshoot.
    input_image_rgb = np.clip(
        np.repeat(input_image_np[..., np.newaxis], 3, axis=-1), 0, 1
    ).astype(np.float32)
    visualization = show_cam_on_image(input_image_rgb, grayscale_cam, use_rgb=True)

    fig = plt.figure(figsize=(6, 6))
    plt.imshow(visualization)
    plt.axis("off")
    plt.title("Grad-CAM Heatmap")
    fig.savefig(output_file)
    # Release the figure so repeated invocations do not accumulate figures.
    plt.close(fig)
34
# Script entry point: hand control to the click command when run directly.
if __name__ == "__main__":
    main()