-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtile_list_batch_iterator.py
More file actions
174 lines (150 loc) · 5.85 KB
/
tile_list_batch_iterator.py
File metadata and controls
174 lines (150 loc) · 5.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import random
import numpy as np
import pandas as pd
from PIL import Image
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.utils import to_categorical
from typing import Generator, List, Tuple
from .utils import load_tile
from .tile_list import TileList
from .global_variables import CLASS_LABEL_TO_INDEX_MAP, NUM_CLASSES
class TileListBatchIterator:
    """
    An iterator class to iterate over all tiles contained in a TileList object.

    Every call to ``__next__`` loads up to ``batch_size`` tiles via ``load_tile``,
    rotates each tile by a randomly chosen multiple of 90 degrees, rescales the
    pixel values and returns the tiles together with their tile information.

    Attributes
    ----------
    _tile_list: TileList
        TileList object containing all relevant tiles.
    _tiles_dir: str
        Directory where the tiles are stored.
    _tile_size: int
        Size of the (rectangular) tiles in pixels.
    _required_pixel_spacing: float
        Required pixel spacing in mm/px. Stored only; none of the methods of
        this class read it.
    _batch_size: int
        Number of tiles to be in one batch.
    _tile_index: int
        Index of the next tile to load. Reset to 0 by ``__iter__``.
    """

    def __init__(self, tile_list: TileList, tiles_dir: str,
                 tile_size: int, required_pixel_spacing: float, batch_size: int):
        """
        Constructor of TileListBatchIterator.

        Parameters
        ----------
        tile_list: TileList
            TileList object containing all relevant tiles.
        tiles_dir: str
            Directory where the tiles are stored.
        tile_size: int
            Size of the (rectangular) tiles in pixels.
        required_pixel_spacing: float
            Required pixel spacing in mm/px.
        batch_size: int
            Number of tiles to be in one batch. Must be at least 1.

        Raises
        ------
        ValueError
            If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        self._tile_list = tile_list
        self._tiles_dir = tiles_dir
        self._tile_size = tile_size
        self._required_pixel_spacing = required_pixel_spacing
        self._batch_size = batch_size
        # Initialised here as well as in __iter__ so that calling next() on a
        # fresh instance works instead of failing with an AttributeError.
        self._tile_index = 0

    @staticmethod
    def _scale_tile(tile: Image.Image) -> np.ndarray:
        """
        Scales image values to [-1, 1], the expected input for InceptionV3 network

        The mapping is ``value / 127.5 - 1.0``, which yields [-1, 1] for input
        values in [0, 255].

        Parameters
        ----------
        tile: Image.Image
            Tile to be rescaled.

        Returns
        -------
        np.ndarray
            Tile with rescaled values.
        """
        return (img_to_array(tile) / 127.5) - 1.0

    @staticmethod
    def _augment(tile: Image.Image) -> Image.Image:
        """
        Performs simple data augmentation by random rotation of the tile.

        The angle is drawn uniformly from 90, 180, 270 and 360 degrees; a full
        turn of 360 degrees leaves the orientation unchanged.

        Parameters
        ----------
        tile: Image.Image
            Tile to be augmented.

        Returns
        -------
        Image.Image
            Augmented tile.
        """
        rotation_angle = random.choice([90, 180, 270, 360])
        return tile.rotate(angle=rotation_angle)

    def __iter__(self) -> "TileListBatchIterator":
        """
        Restarts the iteration at the first tile and returns the iterator itself.
        """
        self._tile_index = 0
        return self

    def __next__(self) -> Tuple[np.ndarray, List[Tuple[str, Tuple[int,int]]]]:
        """
        Prepares next batch of tiles.

        The last batch of a pass may contain fewer than ``batch_size`` tiles.

        Returns
        -------
        tuple
            Tuple with the first element being all tiles in the batch as np.ndarray
            of shape (n, tile_size, tile_size, 3) and the second element being a
            list of the n tile information entries corresponding to the tiles.

        Raises
        ------
        StopIteration
            If all tiles of the tile list have been consumed.
        """
        tiles_left = len(self._tile_list) - self._tile_index
        curr_batch_size = min(self._batch_size, tiles_left)
        if curr_batch_size <= 0:
            raise StopIteration

        # Allocate exactly the number of slots that will be filled, so a short
        # final batch never needs an in-place ndarray.resize afterwards.
        batch_images = np.empty((curr_batch_size, self._tile_size, self._tile_size, 3))
        batch_tile_infos = []
        for slot in range(curr_batch_size):
            tile_info = self._tile_list.get_tile_info(self._tile_index)

            # Open and prepare tile
            tile = load_tile(self._tiles_dir, tile_info)
            tile = self._augment(tile)

            # Add tile to batch
            batch_images[slot] = self._scale_tile(tile)
            batch_tile_infos.append(tile_info)
            # Advance per tile so the position stays accurate even if loading a
            # later tile of this batch raises.
            self._tile_index += 1

        return (batch_images, batch_tile_infos)
def get_tile_generator(tile_list: TileList, tiles_dir: str, tile_size: int,
                       required_pixel_spacing: float, batch_size: int, num_classes: int,
                       slides_metadata: pd.DataFrame,
                       shuffle: bool=True) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
    """
    Yields batches for neural network training / validation.

    The generator never terminates on its own: once all tiles have been served
    it starts a new pass over the tile list (reshuffling first if requested).

    Parameters
    ----------
    tile_list: TileList
        TileList object containing all relevant tiles. Shuffled in place at the
        start of every pass when ``shuffle`` is True.
    tiles_dir: str
        Directory where the tiles are stored.
    tile_size: int
        Size of the (rectangular) tiles in pixels.
    required_pixel_spacing: float
        Required pixel spacing in mm/px.
    batch_size: int
        Number of tiles to be in one batch. Must be at least 1.
    num_classes: int
        Number of classes in the current classification problem.
    slides_metadata: pd.DataFrame
        Metadata of the slides. Looked up with ``.loc`` using the first element
        of each tile information entry and must provide a
        'reference_class_label' column.
    shuffle: bool
        Bool indicating whether to shuffle tiles.

    Yields
    ------
    tuple
        Tuple of the batch images and the one-hot encoded reference labels of
        shape (n, num_classes), where n is the number of tiles in the batch.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1 or ``tile_list`` is empty. As this is
        a generator function, the error surfaces on the first ``next()`` call.
    """
    # Without these guards a pass would produce no batch at all and the endless
    # loop below would spin forever without ever yielding.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    if len(tile_list) == 0:
        raise ValueError("tile_list is empty, no batches can be generated")

    while True:
        if shuffle:
            tile_list.shuffle()

        batch_iterator = TileListBatchIterator(tile_list, tiles_dir, tile_size,
                                               required_pixel_spacing, batch_size)
        for batch_x, batch_tile_infos in batch_iterator:
            batch_y = np.empty((len(batch_tile_infos), num_classes))
            for i, (image_id, _) in enumerate(batch_tile_infos):
                # The label of a tile is the reference class of its slide.
                slide_metadata = slides_metadata.loc[image_id]
                reference_value = CLASS_LABEL_TO_INDEX_MAP[slide_metadata['reference_class_label']]
                batch_y[i] = to_categorical(reference_value, num_classes)
            yield (batch_x, batch_y)