methane_super_emitters.datamodule

import lightning as L
from torch.utils.data import random_split, DataLoader
from methane_super_emitters.dataset import (
    TROPOMISuperEmitterDataset,
    TROPOMISuperEmitterLocatorDataset,
)


class TROPOMISuperEmitterDataModule(L.LightningDataModule):
    def __init__(self, data_dir, fields, batch_size=32, locator=False):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # Pick the locator dataset or the plain super-emitter dataset.
        if locator:
            self.dataset = TROPOMISuperEmitterLocatorDataset(
                self.data_dir, fields=fields
            )
        else:
            self.dataset = TROPOMISuperEmitterDataset(self.data_dir, fields=fields)
        # Random 70/15/15 train/val/test split of the full dataset.
        self.train_set, self.val_set, self.test_set = random_split(
            self.dataset, [0.7, 0.15, 0.15]
        )

    def setup(self, stage):
        pass

    def train_dataloader(self):
        # Only the training loader shuffles.
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False)

    def predict_dataloader(self):
        # Prediction reuses the test split.
        return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False)

    def teardown(self, stage):
        pass
class TROPOMISuperEmitterDataModule(lightning.pytorch.core.datamodule.LightningDataModule):

A DataModule standardizes the train, validation, and test splits, data preparation, and transforms. The main advantage is consistent data splits, data preparation, and transforms across models.

Example::

import lightning as L
import torch
import torch.utils.data as data
from lightning.pytorch.demos.boring_classes import RandomDataset

class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download, IO, etc. Useful with shared filesystems
        # only called on 1 GPU/TPU in distributed
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        dataset = RandomDataset(1, 100)
        self.train, self.val, self.test = data.random_split(
            dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return data.DataLoader(self.train)

    def val_dataloader(self):
        return data.DataLoader(self.val)

    def test_dataloader(self):
        return data.DataLoader(self.test)

    def on_exception(self, exception):
        # clean up state after the trainer faced an exception
        ...

    def teardown(self, stage):
        # clean up state after the trainer stops, delete files...
        # called on every process in DDP
        ...
TROPOMISuperEmitterDataModule(data_dir, fields, batch_size=32, locator=False)
    def __init__(self, data_dir, fields, batch_size=32, locator=False):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        if locator:
            self.dataset = TROPOMISuperEmitterLocatorDataset(
                self.data_dir, fields=fields
            )
        else:
            self.dataset = TROPOMISuperEmitterDataset(self.data_dir, fields=fields)
        self.train_set, self.val_set, self.test_set = random_split(
            self.dataset, [0.7, 0.15, 0.15]
        )

Attributes:

- prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare_data(). Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare the data.
- allow_zero_length_dataloader_with_multiple_devices: If True, a dataloader with zero length within a local rank is allowed. Default: False.

data_dir
batch_size
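
A minimal usage sketch, not taken from this module's source: the `fields` list, `MyModel`, and the data directory below are hypothetical placeholders, and the Trainer settings are illustrative only.

Example::

import lightning as L
from methane_super_emitters.datamodule import TROPOMISuperEmitterDataModule

# Hypothetical field list and patch directory.
fields = ["methane", "lat", "lon"]
dm = TROPOMISuperEmitterDataModule("/path/to/patches", fields=fields, batch_size=32)

model = MyModel()  # any LightningModule; not defined in this module
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)  # draws batches from train_dataloader() and val_dataloader()
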
def setup(self, stage):
    def setup(self, stage):
        pass

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Args: stage: either 'fit', 'validate', 'test', or 'predict'

Example::

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: state assigned in prepare_data() is not shared across processes
        self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
def train_dataloader(self):
    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the "multiple dataloaders" section of the Lightning documentation.

The dataloader you return will not be reloaded unless you set `Trainer.reload_dataloaders_every_n_epochs` to a positive integer.

For data processing use the following pattern:

- download in `prepare_data()`
- process and split in `setup()`

However, the above steps are only necessary for distributed processing.

Warning: do not assign state in `prepare_data()`.

Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
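
A hedged sketch of the reload behaviour mentioned above: `reload_dataloaders_every_n_epochs` is a Trainer argument, while `model` and `dm` stand for a LightningModule and this data module, assumed to be defined elsewhere.

Example::

import lightning as L

# Rebuild train_dataloader() every epoch instead of reusing the cached loader.
trainer = L.Trainer(max_epochs=5, reload_dataloaders_every_n_epochs=1)
trainer.fit(model, datamodule=dm)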

def val_dataloader(self):
    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the "multiple dataloaders" section of the Lightning documentation.

The dataloader you return will not be reloaded unless you set `Trainer.reload_dataloaders_every_n_epochs` to a positive integer.

It's recommended that all data downloads and preparation happen in prepare_data().

See also:

  • Trainer.fit()
  • Trainer.validate()
  • prepare_data()
  • setup()

Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note: If you don't need a validation dataset and a validation_step(), you don't need to implement this method.
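
A hedged sketch of the "multiple dataloaders" case referenced above, assuming two extra validation sets (`val_set_a`, `val_set_b`) that this module does not actually define; Lightning accepts a list of loaders from val_dataloader().

Example::

def val_dataloader(self):
    # With a list, Lightning runs one validation pass per loader and passes
    # a dataloader_idx argument to validation_step().
    return [
        DataLoader(self.val_set_a, batch_size=self.batch_size),
        DataLoader(self.val_set_b, batch_size=self.batch_size),
    ]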

def test_dataloader(self):
    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False)

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the "multiple dataloaders" section of the Lightning documentation.

For data processing use the following pattern:

- download in `prepare_data()`
- process and split in `setup()`

However, the above steps are only necessary for distributed processing.

Warning: do not assign state in `prepare_data()`.

Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note: If you don't need a test dataset and a test_step(), you don't need to implement this method.
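
A short sketch of how this loader is typically consumed, with `model` and `dm` assumed as in the earlier sketches.

Example::

trainer = L.Trainer()
trainer.test(model, datamodule=dm)  # runs test_step() over batches from test_dataloader()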

def predict_dataloader(self):
    def predict_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False)

An iterable or collection of iterables specifying prediction samples.

For more information about multiple dataloaders, see the "multiple dataloaders" section of the Lightning documentation.

It's recommended that all data downloads and preparation happen in prepare_data().

Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return: A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
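
A hedged prediction sketch; note that predict_dataloader() here reuses the test split, and `model`/`dm` are assumed as above.

Example::

trainer = L.Trainer()
# Returns a list with one predict_step() output per batch.
predictions = trainer.predict(model, datamodule=dm)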

def teardown(self, stage):
    def teardown(self, stage):
        pass

Called at the end of fit (train + validate), validate, test, or predict.

Args: stage: either 'fit', 'validate', 'test', or 'predict'
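
The hook is a no-op in this module. As a hedged sketch, an override could clean up temporary artifacts; the scratch directory below is hypothetical and is not created anywhere in this code.

Example::

import shutil

def teardown(self, stage):
    # Hypothetical cleanup of a scratch directory created during setup().
    if getattr(self, "scratch_dir", None):
        shutil.rmtree(self.scratch_dir, ignore_errors=True)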