methane_super_emitters.datamodule
import lightning as L
from torch.utils.data import random_split, DataLoader
from methane_super_emitters.dataset import (
    TROPOMISuperEmitterDataset,
    TROPOMISuperEmitterLocatorDataset,
)


class TROPOMISuperEmitterDataModule(L.LightningDataModule):
    def __init__(self, data_dir, fields, batch_size=32, locator=False):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        if locator:
            self.dataset = TROPOMISuperEmitterLocatorDataset(
                self.data_dir, fields=fields
            )
        else:
            self.dataset = TROPOMISuperEmitterDataset(self.data_dir, fields=fields)
        self.train_set, self.val_set, self.test_set = random_split(
            self.dataset, [0.7, 0.15, 0.15]
        )

    def setup(self, stage):
        pass

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False)

    def predict_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False)

    def teardown(self, stage):
        pass
class TROPOMISuperEmitterDataModule(L.LightningDataModule):
A DataModule standardizes the training, validation, and test splits, data preparation, and transforms. The main advantage is consistent data splits, data preparation, and transforms across models.
Example::

    import lightning as L
    import torch
    import torch.utils.data as data
    from lightning.pytorch.demos.boring_classes import RandomDataset


    class MyDataModule(L.LightningDataModule):
        def prepare_data(self):
            # download, IO, etc. Useful with shared filesystems
            # only called on 1 GPU/TPU in distributed
            ...

        def setup(self, stage):
            # make assignments here (val/train/test split)
            # called on every process in DDP
            dataset = RandomDataset(1, 100)
            self.train, self.val, self.test = data.random_split(
                dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
            )

        def train_dataloader(self):
            return data.DataLoader(self.train)

        def val_dataloader(self):
            return data.DataLoader(self.val)

        def test_dataloader(self):
            return data.DataLoader(self.test)

        def on_exception(self, exception):
            # clean up state after the trainer faced an exception
            ...

        def teardown(self, stage):
            # clean up state after the trainer stops, delete files...
            # called on every process in DDP
            ...
Attributes:
    prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
        Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
    allow_zero_length_dataloader_with_multiple_devices: If True, a dataloader
        with zero length within local rank is allowed. Default value is False.

def __init__(self, data_dir, fields, batch_size=32, locator=False):
    super().__init__()
    self.data_dir = data_dir
    self.batch_size = batch_size
    if locator:
        self.dataset = TROPOMISuperEmitterLocatorDataset(
            self.data_dir, fields=fields
        )
    else:
        self.dataset = TROPOMISuperEmitterDataset(self.data_dir, fields=fields)
    self.train_set, self.val_set, self.test_set = random_split(
        self.dataset, [0.7, 0.15, 0.15]
    )
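A minimal usage sketch for this datamodule follows. The `data_dir` path and `fields` names are hypothetical placeholders; valid values depend on the dataset on disk::

    from methane_super_emitters.datamodule import TROPOMISuperEmitterDataModule

    dm = TROPOMISuperEmitterDataModule(
        data_dir="data/tropomi_patches",  # hypothetical path
        fields=["methane", "qa_value"],   # hypothetical field names
        batch_size=32,
        locator=False,                    # True selects the locator dataset
    )

    # The dataset is built and split in __init__, so the loaders are ready immediately.
    batch = next(iter(dm.train_dataloader()))

    # With a LightningModule, the Trainer drives the hooks documented below:
    #     trainer = L.Trainer(max_epochs=10)
    #     trainer.fit(model, datamodule=dm)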
def setup(self, stage):
    pass

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
Args:
    stage: either 'fit', 'validate', 'test', or 'predict'
Example::

    class LitModel(...):
        def __init__(self):
            self.l1 = None

        def prepare_data(self):
            download_data()
            tokenize()

            # don't do this
            self.something = something_else

        def setup(self, stage):
            data = load_data(...)
            self.l1 = nn.Linear(28, data.num_classes)
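In this module the dataset construction and `random_split` happen in `__init__`, and `setup` is a no-op. A sketch of the same logic moved into `setup`, following the pattern in the Example above (the fixed seed is an added assumption, used so every DDP process produces identical splits)::

    import lightning as L
    import torch
    from torch.utils.data import random_split
    from methane_super_emitters.dataset import (
        TROPOMISuperEmitterDataset,
        TROPOMISuperEmitterLocatorDataset,
    )


    class TROPOMISuperEmitterDataModule(L.LightningDataModule):
        def __init__(self, data_dir, fields, batch_size=32, locator=False):
            super().__init__()
            self.data_dir = data_dir
            self.fields = fields
            self.batch_size = batch_size
            self.locator = locator

        def setup(self, stage):
            # Build the dataset and split it once per process, instead of in __init__.
            dataset_cls = (
                TROPOMISuperEmitterLocatorDataset
                if self.locator
                else TROPOMISuperEmitterDataset
            )
            dataset = dataset_cls(self.data_dir, fields=self.fields)
            self.train_set, self.val_set, self.test_set = random_split(
                dataset,
                [0.7, 0.15, 0.15],
                generator=torch.Generator().manual_seed(42),  # assumed seed
            )

        # train_dataloader(), val_dataloader(), test_dataloader(), and
        # predict_dataloader() stay exactly as in the module source above.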
def train_dataloader(self):
    return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

The dataloader you return will not be reloaded unless you set
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs` to
a positive integer.

For data processing use the following pattern:

    - download in `prepare_data()`
    - process and split in `setup()`

However, the above are only necessary for distributed processing.

.. warning:: do not assign state in `prepare_data()`

During :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`, `prepare_data()` and `setup()` are called before this method.

Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
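For example, to have Lightning rebuild this dataloader at the start of every epoch (a sketch; `model` stands for any LightningModule and `dm` for the datamodule from the earlier sketch)::

    import lightning as L

    # Rebuild the training dataloader every epoch instead of only once.
    trainer = L.Trainer(max_epochs=10, reload_dataloaders_every_n_epochs=1)
    trainer.fit(model, datamodule=dm)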
def val_dataloader(self):
    return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)
An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

The dataloader you return will not be reloaded unless you set
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs` to
a positive integer.

It's recommended that all data downloads and preparation happen in `prepare_data()`.

During :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit` and :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`, `prepare_data()` and `setup()` are called before this method.

Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note: If you don't need a validation dataset and a `validation_step()`, you don't need to implement this method.
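Because this hook may return a collection of iterables, a datamodule can expose more than one validation set. A sketch (the second split, `holdout_set`, is hypothetical and not part of this module)::

    def val_dataloader(self):
        # Returning a list makes Lightning pass dataloader_idx to validation_step.
        return [
            DataLoader(self.val_set, batch_size=self.batch_size),
            DataLoader(self.holdout_set, batch_size=self.batch_size),  # hypothetical
        ]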
def test_dataloader(self):
    return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False)
An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

For data processing use the following pattern:

    - download in `prepare_data()`
    - process and split in `setup()`

However, the above are only necessary for distributed processing.

.. warning:: do not assign state in `prepare_data()`

During :meth:`~lightning.pytorch.trainer.trainer.Trainer.test`, `prepare_data()` and `setup()` are called before this method.

Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note: If you don't need a test dataset and a `test_step()`, you don't need to implement this method.
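Evaluating on the held-out split then looks like this (a sketch; `model` stands for a trained LightningModule and `dm` for the datamodule from the earlier sketch)::

    import lightning as L

    trainer = L.Trainer()
    trainer.test(model, datamodule=dm)  # uses test_dataloader(), i.e. the 15% test split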
def predict_dataloader(self):
    return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False)
An iterable or collection of iterables specifying prediction samples.

For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

It's recommended that all data downloads and preparation happen in `prepare_data()`.

During :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`, `prepare_data()` and `setup()` are called before this method.

Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return:
    A `torch.utils.data.DataLoader` or a sequence of them specifying prediction samples.
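Running inference through this hook (a sketch; `model` stands for a LightningModule with a `predict_step`, `dm` for the datamodule from the earlier sketch)::

    import lightning as L

    trainer = L.Trainer()
    predictions = trainer.predict(model, datamodule=dm)
    # 'predictions' is a list with one entry per batch; this module reuses the
    # test split for prediction, since predict_dataloader() returns self.test_set.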