Is there an existing issue for this?
- [X] I have searched the existing issues
Bug summary
tio.inference.GridSampler and GridAggregator do not allow the model output to be smaller than the input.
I was going to submit this as a feature request before making a PR. However, I realised that tio already supports this, depending on `patch_overlap` and `overlap_mode`, so I believe this should be treated as a bug.
Code for reproduction
# This is not a MWE but a test named `test_inference_smaller.py`
from torch.utils.data import DataLoader
from torchio import DATA
from torchio import LOCATION
from torchio.data.inference import GridAggregator
from torchio.data.inference import GridSampler
from ...utils import TorchioTestCase
class TestInference(TorchioTestCase):
    """Tests for `inference` module."""
    def test_inference_no_padding(self):
        self.try_inference(None)
    def test_inference_padding(self):
        self.try_inference(3)
    def try_inference(self, padding_mode):
        for mode in ["crop", "average", "hann"]:
            for n in 17, 27:
                patch_size = 10, 15, n
                patch_overlap = 0, 0, 0 # <------------- this is important and different from the usual test
                batch_size = 6
                grid_sampler = GridSampler(
                    self.sample_subject,
                    patch_size,
                    patch_overlap,
                    padding_mode=padding_mode,
                )
                aggregator = GridAggregator(grid_sampler, overlap_mode=mode)
                patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
                for patches_batch in patch_loader:
                    input_tensor = patches_batch['t1'][DATA]
                    locations = patches_batch[LOCATION]
                    logits = model(input_tensor)  # some model
                    outputs = logits
                    # Crop one voxel from each side to mimic a model whose output is smaller than its input
                    i_ini, j_ini, k_ini = 1, 1, 1
                    i_fin, j_fin, k_fin = patch_size[0]-1, patch_size[1]-1, patch_size[2]-1
                    outputs = outputs[
                        :,
                        :,
                        i_ini:i_fin,
                        j_ini:j_fin,
                        k_ini:k_fin,
                    ]
                    aggregator.add_batch(outputs, locations)
                output = aggregator.get_output_tensor()
                assert (output == -5).all()
                assert output.shape == self.sample_subject.t1.shape
def model(tensor):
    tensor[:] = -5
    return tensor
Actual outcome
This raises a RuntimeError if `patch_overlap` is smaller than the difference between the input and output sizes and the overlap mode is anything but `crop`.
Below is the output of running pytest tests/data/inference/test_inference_smaller.py
Error messages
==================================================================================================== FAILURES =====================================================================================================
_____________________________________________________________________________________ TestInference.test_inference_no_padding _____________________________________________________________________________________
self = <tests.data.inference.test_inference_smaller.TestInference testMethod=test_inference_no_padding>
    def test_inference_no_padding(self):
>       self.try_inference(None)
test_inference_smaller.py:13: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test_inference_smaller.py:47: in try_inference
    aggregator.add_batch(outputs, locations)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <torchio.data.inference.aggregator.GridAggregator object at 0x7f8353643bb0>
batch_tensor = tensor([[[[[-5, -5, -5,  ..., -5, -5, -5],
           [-5, -5, -5,  ..., -5, -5, -5],
           [-5, -5, -5,  ..., -5..., -5, -5],
           [-5, -5, -5,  ..., -5, -5, -5],
           [-5, -5, -5,  ..., -5, -5, -5]]]]], dtype=torch.int32)
locations = array([[ 0,  0,  0, 10, 15, 17],
       [ 0,  0, 13, 10, 15, 30],
       [ 0,  5,  0, 10, 20, 17],
       [ 0,  5, 13, 10, 20, 30]])
    def add_batch(
            self,
            batch_tensor: torch.Tensor,
            locations: torch.Tensor,
    ) -> None:
        """Add batch processed by a CNN to the output prediction volume.
    
        Args:
            batch_tensor: 5D tensor, typically the output of a convolutional
                neural network, e.g. ``batch['image'][torchio.DATA]``.
            locations: 2D tensor with shape :math:`(B, 6)` representing the
                patch indices in the original image. They are typically
                extracted using ``batch[torchio.LOCATION]``.
        """
        batch = batch_tensor.cpu()
        locations = locations.cpu().numpy()
        patch_sizes = locations[:, 3:] - locations[:, :3]
        # There should be only one patch size
        assert len(np.unique(patch_sizes, axis=0)) == 1
        input_spatial_shape = tuple(batch.shape[-3:])
        target_spatial_shape = tuple(patch_sizes[0])
        if input_spatial_shape != target_spatial_shape:
            message = (
                f'The shape of the input batch, {input_spatial_shape},'
                ' does not match the shape of the target location,'
                f' which is {target_spatial_shape}'
            )
>           raise RuntimeError(message)
E           RuntimeError: The shape of the input batch, (8, 13, 15), does not match the shape of the target location, which is (10, 15, 17)
../../../src/torchio/data/inference/aggregator.py:153: RuntimeError
______________________________________________________________________________________ TestInference.test_inference_padding _______________________________________________________________________________________
self = <tests.data.inference.test_inference_smaller.TestInference testMethod=test_inference_padding>
    def test_inference_padding(self):
>       self.try_inference(3)
test_inference_smaller.py:16: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test_inference_smaller.py:47: in try_inference
    aggregator.add_batch(outputs, locations)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <torchio.data.inference.aggregator.GridAggregator object at 0x7f835149ca90>
batch_tensor = tensor([[[[[-5, -5, -5,  ..., -5, -5, -5],
           [-5, -5, -5,  ..., -5, -5, -5],
           [-5, -5, -5,  ..., -5..., -5, -5],
           [-5, -5, -5,  ..., -5, -5, -5],
           [-5, -5, -5,  ..., -5, -5, -5]]]]], dtype=torch.int32)
locations = array([[ 0,  0,  0, 10, 15, 17],
       [ 0,  0, 13, 10, 15, 30],
       [ 0,  5,  0, 10, 20, 17],
       [ 0,  5, 13, 10, 20, 30]])
    def add_batch(
            self,
            batch_tensor: torch.Tensor,
            locations: torch.Tensor,
    ) -> None:
        """Add batch processed by a CNN to the output prediction volume.
    
        Args:
            batch_tensor: 5D tensor, typically the output of a convolutional
                neural network, e.g. ``batch['image'][torchio.DATA]``.
            locations: 2D tensor with shape :math:`(B, 6)` representing the
                patch indices in the original image. They are typically
                extracted using ``batch[torchio.LOCATION]``.
        """
        batch = batch_tensor.cpu()
        locations = locations.cpu().numpy()
        patch_sizes = locations[:, 3:] - locations[:, :3]
        # There should be only one patch size
        assert len(np.unique(patch_sizes, axis=0)) == 1
        input_spatial_shape = tuple(batch.shape[-3:])
        target_spatial_shape = tuple(patch_sizes[0])
        if input_spatial_shape != target_spatial_shape:
            message = (
                f'The shape of the input batch, {input_spatial_shape},'
                ' does not match the shape of the target location,'
                f' which is {target_spatial_shape}'
            )
>           raise RuntimeError(message)
E           RuntimeError: The shape of the input batch, (8, 13, 15), does not match the shape of the target location, which is (10, 15, 17)
../../../src/torchio/data/inference/aggregator.py:153: RuntimeError
================================================================================================ warnings summary =================================================================================================
test_inference_smaller.py: 16 warnings
  /home/wahab/miniconda3/envs/torchioenv/lib/python3.10/site-packages/SimpleITK/extra.py:183: DeprecationWarning: Converting `np.character` to a dtype is deprecated. The current result is `np.dtype(np.str_)` which is not strictly correct. Note that `np.character` is generally deprecated and 'S1' should be used.
    _np_sitk = {np.dtype(np.character): sitkUInt8,
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================================================= short test summary info =============================================================================================
FAILED test_inference_smaller.py::TestInference::test_inference_no_padding - RuntimeError: The shape of the input batch, (8, 13, 15), does not match the shape of the target location, which is (10, 15, 17)
FAILED test_inference_smaller.py::TestInference::test_inference_padding - RuntimeError: The shape of the input batch, (8, 13, 15), does not match the shape of the target location, which is (10, 15, 17)
========================================================================================= 2 failed, 16 warnings in 0.97s =========================================================================================
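The mismatch in the error message can be checked from the locations printed in the traceback alone. The sketch below uses only NumPy and repeats the patch-size computation shown in `add_batch`; the two-voxel reduction is an assumption taken from the test, which crops one voxel from each side.

```python
import numpy as np

# Locations copied from the traceback:
# (i_ini, j_ini, k_ini, i_fin, j_fin, k_fin) for each patch
locations = np.array([
    [0, 0, 0, 10, 15, 17],
    [0, 0, 13, 10, 15, 30],
    [0, 5, 0, 10, 20, 17],
    [0, 5, 13, 10, 20, 30],
])

# Same computation as in add_batch: patch size implied by each location
patch_sizes = locations[:, 3:] - locations[:, :3]
target_spatial_shape = tuple(int(n) for n in patch_sizes[0])

# The test removes one voxel from each side of every spatial dimension
input_spatial_shape = tuple(n - 2 for n in target_spatial_shape)

# This equality is what add_batch requires, so the cropped output fails it
shapes_match = input_spatial_shape == target_spatial_shape
print(target_spatial_shape, input_spatial_shape, shapes_match)
```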
Expected outcome
I believe tio should be able to handle model outputs that are smaller than the input patch. My model's predictions are terrible even with averaging or Hann windowing. Unfortunately, most popular model libraries (such as the great MONAI) only provide models whose output size equals the input size. In my application it is crucial that the model sees a larger input ROI than the region it labels, because the surrounding voxels give context for the prediction. Unpadded (valid) convolutions have exactly this effect: the original U-Net paper uses them, so its outputs are smaller than its inputs.
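To make the size reduction concrete, here is a minimal sketch of the standard convolution output-size formula (the one given in the PyTorch `Conv3d` documentation). The helper name `conv_output_size` is hypothetical, and a single 3x3x3 convolution is assumed purely for illustration; it is not the model used in this issue.

```python
def conv_output_size(size, kernel_size=3, padding=0, stride=1, dilation=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


patch_size = (10, 15, 17)

# Unpadded (valid) convolution: each dimension loses kernel_size - 1 voxels
unpadded = tuple(conv_output_size(n) for n in patch_size)

# Padded ("same") convolution: the output keeps the input size
padded = tuple(conv_output_size(n, padding=1) for n in patch_size)

print(unpadded, padded)
```

With these assumptions, one unpadded layer turns a `(10, 15, 17)` patch into `(8, 13, 15)`, which matches the shapes in the error message above.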
I am going to make a PR tomorrow to add a fix for this. My plan is to change only GridAggregator and leave the sampler as it is:
- [x] Check if the aggregator input is smaller than the sampler output in `GridAggregator.add_batch()` before comparing it to the location patch size
- [x] Create a variable in the aggregator called `patch_diffs`, which is the difference between `input_spatial_shape` and `target_spatial_shape`
- [x] Change each dimension of `self.patch_overlap` to `patch_diffs` if it is smaller
- [ ] ~Edit each location before cropping by adding half the diffs to `i_ini` etc. and removing half the diffs from `i_fin`~
- [x] Write a new unit test (Let me know if this can be improved)
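The arithmetic behind the checked items above can be sketched as follows. This is only an illustration under my own assumptions, not the code of the PR: `adjust_patch_overlap` is a hypothetical helper that does not exist in TorchIO, and how the aggregator would then use the adjusted overlap is not shown here.

```python
import numpy as np


def adjust_patch_overlap(patch_overlap, input_spatial_shape, target_spatial_shape):
    """Hypothetical helper, not part of TorchIO.

    Naming follows add_batch: "input" is the model output passed to the
    aggregator, "target" is the patch size implied by the locations.
    """
    overlap = np.asarray(patch_overlap)
    patch_diffs = np.asarray(target_spatial_shape) - np.asarray(input_spatial_shape)
    if (patch_diffs < 0).any():
        raise ValueError('The model output is larger than the sampled patch')
    # Raise the overlap to the size difference wherever it is smaller
    return patch_diffs, np.maximum(overlap, patch_diffs)


# Values from the failing test: no overlap, output two voxels smaller per axis
patch_diffs, new_overlap = adjust_patch_overlap(
    (0, 0, 0), (8, 13, 15), (10, 15, 17),
)
print(patch_diffs, new_overlap)
```

An overlap that is already larger than the difference is left unchanged, so existing configurations with enough overlap would behave as before.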
If you see an issue with this happening behind the scenes, should `model_output_size` be added as an argument to GridAggregator or GridSampler? Or should the aggregator raise a warning when it detects a smaller output?
This is a bit confusing even in the code, as the model's output is the aggregator's input. I've tried to be clear here; let me know if I haven't.
System info
Platform:   Linux-5.4.0-131-generic-x86_64-with-glibc2.27
TorchIO:    0.18.86
PyTorch:    1.13.0+cu117
SimpleITK:  2.2.0 (ITK 5.3)
NumPy:      1.23.4
Python:     3.10.8 (main, Nov  4 2022, 13:48:29) [GCC 11.2.0]
bug