import itertools
import math
from functools import lru_cache, partial
import numpy as np
from chunkblocks.iterators import UnitBFSIterator
[docs]@lru_cache(maxsize=None)
def all_borders(dimensions):
return tuple(itertools.product(range(0, dimensions), (-1, 1)))
[docs]def sub(slice_left, slice_right):
"""
Removes the right slice from the left. Does NOT account for slices on the right that do not touch the border of the
left
"""
start = 0
stop = 0
if slice_left.start == slice_right.start:
start = min(slice_left.stop, slice_right.stop)
stop = max(slice_left.stop, slice_right.stop)
if slice_left.stop == slice_right.stop:
start = min(slice_left.start, slice_right.start)
stop = max(slice_left.start, slice_right.start)
return slice(start, stop)
[docs]class Chunk(object):
__slots__ = ('unit_index', 'slices', 'offset', 'data', 'all_borders', 'block')
def __init__(self, block, unit_index):
self.block = block
self.unit_index = unit_index
self.slices = block.unit_index_to_slices(unit_index)
self.offset = tuple(s.start for s in self.slices)
self.data = None
self.all_borders = all_borders(len(self.shape))
@property
def overlap(self):
return self.block.overlap
@property
def shape(self):
return self.block.chunk_shape
[docs] def squeeze_slices(self, slices):
"""
Ensure that the slices match the maximum permissible bounds of this particular chunk.
Used for datasources that are unable to pad or handle out of bounds arrays properly
"""
return tuple(
slice(
None if sl.start is None else sl.start if sl.start > bounds.start else bounds.start,
None if sl.stop is None else sl.stop if sl.stop < bounds.stop else bounds.stop,
)
for bounds, sl in zip(self.data.bounds(), slices)
)
[docs] def match_datasource_dimensions(self, datasource, slices):
matched_slices = (slice(None),) * (len(datasource.shape) - len(slices)) + slices
if self.data is not None:
matched_slices = self.squeeze_slices(matched_slices)
return matched_slices
[docs] def load_data(self, datasource, slices=None):
if slices is None:
slices = self.slices
slices = self.match_datasource_dimensions(datasource, slices)
if self.data is None:
self.data = datasource[slices].copy()
else:
self.data[slices] = datasource[slices]
return self
[docs] def dump_data(self, datasource, slices=None):
if slices is None:
slices = self.slices
slices = self.match_datasource_dimensions(datasource, slices)
datasource[slices] = self.data[slices]
return self
[docs] def copy_data(self, source, destination, slices=None):
if slices is None:
slices = self.slices
slices = self.match_datasource_dimensions(destination, slices)
destination[slices] = source[slices]
[docs] def __eq__(self, other):
return isinstance(other, Chunk) and self.unit_index == other.unit_index
[docs] def __hash__(self):
return hash(self.unit_index)
[docs] def core_slices(self, borders=None):
"""
Returns a list of non-intersecting slices that is excluded by the requested borders. Borders is a list of
tuples:
(dimension index of border, border direction)
Border direction is specified by -1 to represent the border in the negative index direction and +1 for the
positive index direction.
"""
if borders is None:
borders = self.all_borders
core_slices = list(self.slices)
for border, direction in borders:
core_slice = core_slices[border]
if direction < 0:
core_slice = slice(core_slice.start + self.overlap[border], core_slice.stop)
else:
core_slice = slice(core_slice.start, core_slice.stop - self.overlap[border])
core_slices[border] = core_slice
return tuple(core_slices)
[docs] def border_slices(self, borders=None, nonintersecting=True):
"""
Returns a list of slices that cover the requested borders.
:param borders: list of tuples indicating (dimension index of border, border direction)
When no borders are given, return all borders.
Border direction is specified by -1 to represent the border in the negative index direction and +1 for the
positive index direction.
:param nonintersecting: if set to False, will return slices that will account for each index only *once*. if set to
True, will indescriminately return the largest slices that will include the corners and edges more than once.
"""
if borders is None:
borders = self.all_borders
border_slices = []
processed_dimensions = set()
remainders = list(self.slices)
for border, direction in borders:
if direction < 0:
border_slice = slice(self.slices[border].start, self.slices[border].start +
self.overlap[border])
else:
border_slice = slice(self.slices[border].stop - self.overlap[border],
self.slices[border].stop)
new_slices = tuple(
border_slice if idx == border else
remainders[idx] if idx in processed_dimensions else
self.slices[idx]
for idx in range(0, len(self.slices))
)
if nonintersecting:
remainders[border] = sub(remainders[border], new_slices[border])
border_slices.append(new_slices)
processed_dimensions.add(border)
return border_slices
[docs]class Block(object):
def __init__(self, bounds=None, offset=None, num_chunks=None, chunk_shape=None, overlap=None, base_iterator=None):
"""
Create a block which is used to addres chunks. Must specify either bounds or (offset and num_chunks) to
determine the size of the dataset.
"""
if not overlap:
overlap = tuple([0] * len(chunk_shape))
self.overlap = overlap
self.chunk_shape = tuple(chunk_shape)
if not base_iterator:
base_iterator = UnitBFSIterator()
self.base_iterator = base_iterator
self.strides = tuple((c_shape - olap) for c_shape, olap in zip(self.chunk_shape, self.overlap))
contains_bounds = bounds is not None
contains_offset = offset is not None and num_chunks is not None
if not contains_bounds and not contains_offset:
raise ValueError('Either bounds or offset/num_chunks must be specified')
if contains_bounds:
self.offset = tuple(s.start for s in bounds)
self.bounds = tuple(bounds)
self.shape = tuple(b.stop - b.start for b in self.bounds)
self.num_chunks = tuple((shp - olap) // s for shp, olap, s in zip(
self.shape, self.overlap, self.strides))
if contains_offset:
bounds = tuple(slice(o, o + chks * st + olap) for o, chks, st, olap in zip(
offset, num_chunks, self.strides, self.overlap))
shape = tuple(chunks * st + olap for chunks, st, olap in zip(num_chunks, self.strides, self.overlap))
if contains_bounds:
assert self.bounds == bounds, "Received both bounds and offset/num_chunks that do not match"
assert self.shape == shape, "Received both bounds and offset/num_chunks that do not match"
assert self.num_chunks == num_chunks, "Received both bounds and offset/num_chunks that do not match"
else:
self.offset = offset
self.bounds = bounds
self.shape = shape
self.num_chunks = num_chunks
self.bounds = bounds
Block.verify_size(self.num_chunks, self.chunk_shape, self.shape, self.overlap)
self.checkpoints = []
self.unit_index_to_chunk = partial(Chunk, self)
[docs] def unit_index_to_slices(self, index):
return tuple(slice(b.start + idx * s, b.start + idx * s + c_shape) for b, idx, s, c_shape in zip(
self.bounds, index, self.strides, self.chunk_shape))
[docs] def chunk_slices_to_unit_index(self, slices):
"""
Get the corresponding chunk index for this chunk_slice
"""
# remove dimension for channel
slices = slices[-len(self.chunk_shape):]
return tuple((slice.start - b.start) // s for b, s, slice in zip(self.bounds, self.strides, slices))
[docs] def slices_to_unit_indices(self, slices):
"""
Get the corresponding unit indices that cover these slices
"""
# remove dimension for channel
slices = slices[-len(self.chunk_shape):]
return itertools.product(
*[
range(
# set start 0 if slice begins in first overlap area to prevent negative index
# otherwise take floor div for the number of strides from bound start - offset of overlap
0 if sl.start is None or sl.start < b.start + o else (sl.start - b.start - o) // s,
# set end to chunks if slice ends in the last overlap area to prevent index > chunks
# otherwise take ceil div for the number of strides from start (no offset needed)
chunks if sl.stop is None or sl.stop >= b.stop - o else math.ceil((sl.stop - b.start) / s)
) for b, s, chunks, sl, o in zip(self.bounds, self.strides, self.num_chunks, slices, self.overlap)
]
)
[docs] def slices_to_chunks(self, slices):
"""
Get the corresponding chunks that cover these slices
"""
return map(self.unit_index_to_chunk, self.slices_to_unit_indices(slices))
[docs] @staticmethod
def verify_size(num_chunks, chunk_shape, shape, overlap):
for chunks, c_shape, shp, olap in zip(num_chunks, chunk_shape, shape, overlap):
if chunks * (c_shape - olap) + olap != shp:
raise ValueError('Data size %s divided by %s with overlap %s does not divide evenly' % (
shape, chunk_shape, overlap))
[docs] def ensure_checkpoint_stage(self, stage):
try:
return self.checkpoints[stage]
except IndexError:
while len(self.checkpoints) < stage + 1:
self.checkpoints.append(np.zeros(self.num_chunks, dtype=np.bool))
return self.checkpoints[stage]
[docs] def checkpoint(self, chunk, stage=0):
self.ensure_checkpoint_stage(stage)[chunk.unit_index] = True
[docs] def get_all_neighbors(self, chunk):
return map(self.unit_index_to_chunk,
self.base_iterator.get_all_neighbors(chunk.unit_index, max=self.num_chunks))
[docs] def is_checkpointed(self, chunk, stage=0):
return self.ensure_checkpoint_stage(stage)[chunk.unit_index]
[docs] def all_neighbors_checkpointed(self, chunk, stage=0):
return self.all_checkpointed(self.get_all_neighbors(chunk), stage)
[docs] def all_checkpointed(self, chunks, stage=0):
checkpoints = self.ensure_checkpoint_stage(stage)
return all(checkpoints[chunk.unit_index] for chunk in chunks)
[docs] def chunk_iterator(self, start=None):
if start is None:
start_index = (0,) * len(self.shape)
elif isinstance(start, Chunk):
start_index = start.unit_index
else:
start_index = start
yield from map(self.unit_index_to_chunk, self.base_iterator.get(start_index, self.num_chunks))
[docs] def core_slices(self, chunk):
"""
Returns the slices of the chunk that corresponds to the block's core that has no overlap with other blocks.
"""
intersect_slices = []
for s, b, olap, idx in zip(chunk.slices, self.bounds, self.overlap, range(0, len(chunk.slices))):
if s.start == b.start:
intersect_slices.append(slice(s.start + olap, s.stop))
elif s.stop == b.stop:
intersect_slices.append(slice(s.start, s.stop - olap))
else:
intersect_slices.append(s)
return tuple(self.remove_chunk_overlap(chunk, intersect_slices))
[docs] def overlap_borders(self, chunk):
"""
Get a list of borders in the chunk that correspond to the block's overlap region.
Returns list of borders in the form of tuples:
(dimension index of border, border direction)
Border direction is specified by -1 to represent the border in the negative index direction and +1 for the
positive index direction.
See py:method::overlap_slices(chunk) for usage
"""
# determine the common intersect slices within the chunk
borders = []
for s, b, olap, idx in zip(chunk.slices, self.bounds, self.overlap, range(0, len(chunk.slices))):
if s.start == b.start:
borders.append((idx, -1))
elif s.stop == b.stop:
borders.append((idx, 1))
return borders
[docs] def remove_chunk_overlap(self, chunk, overlapped_slices):
"""
Modify slices to remove the common intersection of the chunks within the block. Common intersections are
excluded in a index first fashion, i.e. the slices do not include the portion of the data that will be
accounted for by the next chunk ( next chunk is of a greater index ).
See py:method::overlap_slices_with_borders(chunk) for usage
"""
return tuple(
slice(o_slice.start, o_slice.stop - olap) if o_slice.stop == s.stop and o_slice.stop != b.stop else
o_slice
for s, o_slice, olap, b in zip(chunk.slices, overlapped_slices, self.overlap, self.bounds)
)
[docs] def overlap_slices(self, chunk):
"""
Get a list of the slices in the chunk that correspond to the block's overlap region (i.e. the borders of the
block) with chunk overlaps removed.
See py:method::overlap_slices(chunk) for more details
"""
return self.overlap_slices_with_borders(chunk, self.overlap_borders(chunk))
[docs] def overlap_chunk_slices(self, chunk):
"""
Get a list of the all the chunks overlaps with overlaps across chunks accounted for only once.
See py:method::overlap_slices_with_borders(chunk, borders) for more details
"""
return self.overlap_slices_with_borders(chunk, all_borders(len(self.shape)))
[docs] def overlap_slices_with_borders(self, chunk, borders):
"""
Get a list of the slices in the chunk that correspond the input borders with chunk overlaps removed.
If we have a block:
dimensions: 7x7
chunk_shape: 3x3
overlap: 1x1
This should result in 3x3 chunks. When this function is called with each of these chunks, slices that cover the
overlap region are returned with no duplicates. Additionally, overlaps across chunks are excluded in a index
first fashion, i.e. the slices do not include the portion of the data that had should be accounted for by the
next chunk ( next chunk meaning of a greater index ).
At the non corner chunks, we expect to return a single tuple of slices that
cover the overlap region, i.e.(not actual format, dictionary used for clarity)
x: slice(0, 1), y: slice(2, 5)
For corner chunks, this takes care of overlapping areas so they do not get counted twice. For example, for the
chunk at position (0, 0), we should expect to return the tuples of slices:
x1: slice(0, 3), y1: slice(0, 1)
x2: slice(0, 1), y2: slice(1, 3)]
WARNING: not tested for dimensions > 3.
"""
return [
slices for slices in [
self.remove_chunk_overlap(chunk, overlapped_slice)
for overlapped_slice in chunk.border_slices(borders)
] if all(s.stop != s.start for s in slices)
]