zarr_parallel.region

Classes

zarr_parallel.region.RegionWorker

Source code in zarr_parallel/region.py
class RegionWorker:
    def __init__(self, id: str, config: str, heartbeat_timeout: Union[int,None] = None):

        with open(config) as f:
            content = json.load(f)

        self.id = int(id)
        self.dsinfo      = content['dataset']
        self.transforms  = content['common']['pre_transforms']
        self.variables   = content['variables']
        self.region_isel = content['region_info']['region_isel']

        # Non-tiled datasets will have the same parallelisable and source dims.
        self.parallelisable_dims  = content['region_info']['dims']
        self.fine_dims   = content['region_info'].get('fine_dims',{})
        self.source_dims = list(self.region_isel.keys())

        self.tiled = list(self.parallelisable_dims.keys()) != self.source_dims

        self.source_chunks = content['source_chunks']
        self.output_chunks = content['output_chunks']
        self.memory_limit = content['memory_limit']

        self.heartbeat = heartbeat_timeout

        # Determine coordinate/region extents
        self.coord_extent, self.region_extent   = self.map_region()

        # Determine my coordinates
        self.coords = self.id_to_coord()

        # Determine my region
        self.region = self.region_from_coords()

        self.dslice = self.resolve_region()

        self._prepare_dataset()

    def map_region(self):
        coord_extent = [math.ceil((v['source_max']-v['source_min'])/v['worker_size']) for v in self.parallelisable_dims.values()]
        region_extent = [int(v['source_max']-v['source_min']) for v in self.parallelisable_dims.values()]

        return coord_extent, region_extent

    def id_to_coord(self):

        if self.id > math.prod(self.coord_extent)-1:
            raise ValueError(f'ID {self.id} invalid for space with {math.prod(self.coord_extent)} tiles')

        coords = []
        for dim in range(len(self.coord_extent)):
            coord = 0

            # Calculate stride product
            strideprod = 1
            for stride in self.coord_extent[dim+1:]:
                strideprod *= stride
            coord += math.floor(self.id/strideprod) % self.coord_extent[dim]

            coords.append(int(coord))
        return coords

    def region_from_coords(self):

        region = {}
        for i, (dim, dinfo) in enumerate(self.parallelisable_dims.items()):

            rmin = dinfo['worker_size']*self.coords[i]
            rmax = dinfo['worker_size']*(self.coords[i]+1)

            if self.coords[i] == self.coord_extent[i]-1:
                rmax = self.region_extent[i]
            region[dim] = slice(rmin,rmax)

        # Add tiled fine dims in the case of additional fine dims
        for dim, dinfo in self.fine_dims.items():
            region[dim] = slice(dinfo['source_min'], dinfo['source_max'])

        return region

    def start_from(self, var: str, ndims: int, chunks: dict) -> int:
        """
        Detect past progress on writing the zarr store

        Limitation: Primary chunk dimension must be the first dimension.
        """

        trailing_zeros = '.'.join(["0" for x in range(ndims-1)])

        primary_dim = list(self.parallelisable_dims.keys())[0]

        # Number of cache chunks per worker along the primary dimension
        chunks_per_worker = int(
            self.parallelisable_dims[primary_dim]['worker_size']
            // self.parallelisable_dims[primary_dim]['cache_size'])

        # First chunk index owned by this worker
        zero_chunk = self.coords[0] * chunks_per_worker
        chunk_id   = zero_chunk

        # Advance past any chunk files already written to the cache
        file = f"{self.dsinfo['zarr_cache']}/{var}/{chunk_id}." + trailing_zeros
        while os.path.isfile(file):
            logger.debug(f"Located {file}")
            chunk_id += 1
            file = f"{self.dsinfo['zarr_cache']}/{var}/{chunk_id}." + trailing_zeros

        if chunk_id > zero_chunk:
            logger.info(f"Resuming from chunk ID: {chunk_id}")
        else:
            logger.info(f"Starting from chunk ID: {chunk_id}")
        return chunk_id - zero_chunk

    def write_data_region(self):

        chunks = self.output_chunks

        logger.info(f'ID: {self.id}')
        logger.info(f'Coords: {self.coords}')
        logger.info(f'Chunks: {chunks}')

        darrs = self.extract_subset()
        for darr in darrs:

            var = darr.name

            logger.info(f"Writing {var} Region")

            # If rechunking, a heartbeat timeout, resumption or tiling is
            # required, the write is split into sections.
            start_from = self.start_from(var, len(darr.dims), chunks)

            force_rechunk = chunks != {} and chunks != self.source_chunks
            if not force_rechunk and not self.heartbeat and not start_from and not self.tiled:
                # Write the whole region to the zarr cache
                darr.to_zarr(
                    self.dsinfo['zarr_cache'], 
                    zarr_format=2, 
                    compute=True, 
                    consolidated=True,
                    region=self.region,
                    write_empty_chunks=True,
                    mode='r+')

            else:
                self._balanced_chunk_write(var, darr, chunks, start_from=start_from)

        logger.info(f'Complete for {self.coords}')

    def _balanced_chunk_write(self, var: str, darr: xr.DataArray, chunks: dict, start_from: int = 0):
        """
        Control the rate of chunk writes based on time/memory requirements
        """

        # Standard approach: single chunk output at a time.

        # Balanced approach:
        # - Chunking invokes numpy arrays - balance memory up to limit
        # - Dask workers invoke heartbeat - balance timeout up to limit
        # - Split chunk writes require max number of chunks per-write that fit within limits

        primary_dim   = list(self.parallelisable_dims.keys())[0]
        chunk_size    = chunks[primary_dim]
        # Resume after any chunks already present in the cache
        prime_slice   = int(start_from * chunk_size)

        # Byte limit for memory
        memory_limit_bytes  = 0.85 * interpret_mem_limit(self.memory_limit)
        # Approximate elements per chunk (assumes 8-byte values)
        mem_chunks = math.prod(chunks.values())
        # Chunk limit based on memory
        max_mem_batch_chunk = math.floor(memory_limit_bytes / (mem_chunks * 8))
        # Offset to write region into parallel dataset
        region_write_offset = self.coords[0]*self.parallelisable_dims[primary_dim]['worker_size']

        # Limitation: Tiled datasets will always result in 1-1 tile-chunking
        if self.tiled:
            darr = darr.isel(**{
                primary_dim: slice(
                    region_write_offset,
                    region_write_offset + self.parallelisable_dims[primary_dim]['worker_size']
                )
            })
            mem_chunks = 1
            for dim, dinf in self.fine_dims.items():
                chunks[dim] = dinf['source_max'] - dinf['source_min']

                # If source chunks are larger than the tile selection, memory size is based on the source chunks
                mem_chunks *= max(self.source_chunks.get(dim.split('_')[0], chunks[dim]), chunks[dim])

            max_mem_batch_chunk = math.floor(memory_limit_bytes / (mem_chunks * 8))

        if max_mem_batch_chunk < 1:
            raise ValueError(
                f'Memory limit too low to process even a single chunk. '
                f'Limit: {self.memory_limit}, Approx Required: {mem_chunks*8/1e6 :.2f} MB')

        chunk_batch = int(max_mem_batch_chunk)

        # Number of chunks to write
        nchunks     = math.ceil(darr[primary_dim].size/chunk_size)

        logger.info(f'Balancing chunk writes for {nchunks} chunks')

        complete = False
        while not complete:

            # Recalculate limits for chunk batch.
            prime_slice_lim = int(prime_slice + chunk_batch*chunk_size)
            # Handle final case + overflowing chunk batch request size
            if prime_slice_lim > darr[primary_dim].size:

                chunk_batch = int((darr[primary_dim].size - prime_slice)/chunk_size)
                prime_slice_lim = int(darr[primary_dim].size)
                complete = True

            timings = [datetime.now()]
            ds_sub = darr.isel(**{primary_dim: slice(prime_slice, prime_slice_lim)}).compute()

            # Append timing for numpy casting
            timings.append(datetime.now())

            ds_region = xr.Dataset(
                {d: ds_sub[d].to_numpy() for d in ds_sub.dims})

            dask_chunks = tuple(chunks.values())
            ds_region[var] = xr.DataArray(da.from_array(ds_sub.to_numpy(), chunks=dask_chunks), dims=list(chunks.keys()))

            region_dict = {
                primary_dim:slice(
                    region_write_offset + prime_slice, 
                    region_write_offset + prime_slice_lim
                )
            }
            region_dict.update({d: slice(0, chunks[d]) for d in chunks.keys() if d != primary_dim})

            logger.info(
                f'Writing region '
                f'({prime_slice}, {prime_slice_lim}) -> '
                f'({region_write_offset + prime_slice}, {region_write_offset + prime_slice_lim})'
            )

            ds_region.to_zarr(
                self.dsinfo['zarr_cache'],
                compute=True,
                consolidated=True,
                zarr_format=2,
                region=region_dict,
                mode='r+',
                safe_chunks=False,
            )
            timings.append(datetime.now())

            # Next iteration, update start position
            prime_slice += chunk_batch*chunk_size

            # Increase chunk usage (if possible) or decrease as necessary
            if self.heartbeat is not None:

                max_time = max([(t - timings[0]).total_seconds() for t in timings])

                # Timeout comparison formula - allows the chunk batch size to
                # grow while the heartbeat timeout permits it.
                estm_chunk_limit = max(
                    1, int(abs((self.heartbeat*0.85 - max_time) * chunk_batch / max_time)) // 2
                )

                if max_time < self.heartbeat*0.85:
                    chunk_batch += estm_chunk_limit
                    if chunk_batch > max_mem_batch_chunk:
                        chunk_batch = max_mem_batch_chunk
                    logger.debug(f' > Increased to {chunk_batch} chunks')
                else:
                    # Never drop below a single chunk per write
                    chunk_batch = max(1, chunk_batch - estm_chunk_limit)
                    logger.debug(f' > Decreased to {chunk_batch} chunks')

        logger.info('All chunks written to zarr store')

    def _prepare_dataset(self):

        self.ds = xr.open_dataset(
            self.dsinfo['uri'],
            engine=self.dsinfo['engine'],
            chunks='auto',
            **self.dsinfo.get('kwargs',{})
        )

    def resolve_region(self) -> dict[str, slice]:
        """
        Resolve region to determine slice to subset

        Currently superfluous as the region is equal to the slice,
        but if the improvement to directly slice from source is made,
        this function becomes useful again."""

        dslice = {}
        dcount = 0

        # Tiled dataset - special case for resolving the region
        if self.tiled:
            for d, v in self.region_isel.items():
                dslice[d] = slice(v['source_min'], v['source_max'])
            return dslice

        # This creates the pre-tiled slice to apply to the dataset.
        for d, v in self.region_isel.items():
            dmin = v['source_min'] + self.coords[dcount]*v['worker_size']
            dmax = dmin + v['worker_size']

            # Adjust to fit boundary in the case of smaller final chunk
            if self.coords[dcount] == self.coord_extent[dcount]-1:
                dmax = v['source_max']

            dslice[d] = slice(dmin, dmax)
            dcount += 1

        return dslice

    def extract_subset(self) -> xr.DataArray:
        """
        Open a remote dataset and extract an xarray DataArray
        """

        # All selected transforms applied in correct order.
        transformed = apply_transforms(
            self.ds,
            common_transforms=self.transforms,
            variable_transforms=self.variables,
            region_transform=self.dslice
        )
        return transformed['datasets']
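
All of the worker's behaviour is driven by the JSON config file passed to the constructor, with the worker's numeric id selecting which block of the parallelisable dimensions it owns. The sketch below is illustrative only: the keys mirror those read in __init__, but every value (paths, dimension names, sizes, engine) is a placeholder, and the zarr cache is assumed to have already been initialised with the target array metadata.

import json
from zarr_parallel.region import RegionWorker

# Hypothetical worker config - keys match those read by RegionWorker.__init__,
# all values are placeholders.
worker_config = {
    "dataset": {
        "uri": "https://example.org/source.nc",   # placeholder source dataset
        "engine": "h5netcdf",                      # any xarray-supported engine
        "zarr_cache": "/scratch/cache.zarr",       # placeholder output store
    },
    "common": {"pre_transforms": []},
    "variables": {},     # per-variable transforms (schema defined by apply_transforms)
    "region_info": {
        # Slice of the source dataset handled by this job
        "region_isel": {
            "time": {"source_min": 0, "source_max": 1000, "worker_size": 100},
        },
        # Parallelisable dims: 1000 indices split into worker_size=100 blocks
        "dims": {
            "time": {"source_min": 0, "source_max": 1000,
                     "worker_size": 100, "cache_size": 10},
        },
    },
    "source_chunks": {"time": 10, "lat": 180, "lon": 360},
    "output_chunks": {"time": 10, "lat": 180, "lon": 360},
    "memory_limit": "2GB",   # format accepted by interpret_mem_limit (assumed)
}

with open("worker_config.json", "w") as f:
    json.dump(worker_config, f)

# Worker 3 of ceil(1000/100) == 10 owns time indices 300-400.
worker = RegionWorker(id="3", config="worker_config.json", heartbeat_timeout=300)
worker.write_data_region()
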
Functions
zarr_parallel.region.RegionWorker.extract_subset()

Open a remote dataset and extract an xarray DataArray

Source code in zarr_parallel/region.py
def extract_subset(self) -> xr.DataArray:
    """
    Open a remote dataset and extract an xarray DataArray
    """

    # All selected transforms applied in correct order.
    transformed = apply_transforms(
        self.ds,
        common_transforms=self.transforms,
        variable_transforms=self.variables,
        region_transform=self.dslice
    )
    return transformed['datasets']
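
Despite the singular return annotation, write_data_region iterates over the result, so in practice the call yields one named DataArray per configured variable with the common, per-variable and region transforms already applied by apply_transforms. A minimal usage sketch, assuming a worker constructed as in the config example after the class source above:

# `worker` is an initialised RegionWorker (see the earlier config sketch).
for darr in worker.extract_subset():
    # Each item behaves as a named xarray DataArray restricted to this
    # worker's slice of the source dataset.
    print(darr.name, dict(darr.sizes))
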
zarr_parallel.region.RegionWorker.resolve_region()

Resolve region to determine slice to subset

Currently superfluous as the region is equal to the slice, but if the improvement to directly slice from source is made, this function becomes useful again.

Source code in zarr_parallel/region.py
def resolve_region(self) -> dict[str, slice]:
    """
    Resolve region to determine slice to subset

    Currently superfluous as the region is equal to the slice,
    but if the improvement to directly slice from source is made,
    this function becomes useful again."""

    dslice = {}
    dcount = 0

    # Tiled dataset - special case for resolving the region
    if self.tiled:
        for d, v in self.region_isel.items():
            dslice[d] = slice(v['source_min'], v['source_max'])
        return dslice

    # This creates the pre-tiled slice to apply to the dataset.
    for d, v in self.region_isel.items():
        dmin = v['source_min'] + self.coords[dcount]*v['worker_size']
        dmax = dmin + v['worker_size']

        # Adjust to fit boundary in the case of smaller final chunk
        if self.coords[dcount] == self.coord_extent[dcount]-1:
            dmax = v['source_max']

        dslice[d] = slice(dmin, dmax)
        dcount += 1

    return dslice
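
For a non-tiled dataset each dimension's slice starts at source_min plus the worker coordinate times worker_size and extends one worker_size further, with the last worker along a dimension clamped to source_max. The standalone sketch below reproduces that arithmetic with made-up sizes:

# Standalone illustration of the non-tiled slice arithmetic (made-up sizes).
import math

dinfo = {"source_min": 0, "source_max": 1050, "worker_size": 100}
coord_extent = math.ceil((dinfo["source_max"] - dinfo["source_min"]) / dinfo["worker_size"])  # 11

for coord in (0, 10):
    dmin = dinfo["source_min"] + coord * dinfo["worker_size"]
    dmax = dmin + dinfo["worker_size"]
    if coord == coord_extent - 1:        # final worker takes the short remainder
        dmax = dinfo["source_max"]
    print(coord, slice(dmin, dmax))
# 0  -> slice(0, 100)
# 10 -> slice(1000, 1050)
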
zarr_parallel.region.RegionWorker.start_from(var, ndims, chunks)

Detect past progress on writing the zarr store

Limitation: Primary chunk dimension must be the first dimension.

Source code in zarr_parallel/region.py
def start_from(self, var: str, ndims: int, chunks: dict) -> int:
    """
    Detect past progress on writing the zarr store

    Limitation: Primary chunk dimension must be the first dimension.
    """

    trailing_zeros = '.'.join(["0" for x in range(ndims-1)])

    primary_dim = list(self.parallelisable_dims.keys())[0]

    # Number of cache chunks per worker along the primary dimension
    chunks_per_worker = int(
        self.parallelisable_dims[primary_dim]['worker_size']
        // self.parallelisable_dims[primary_dim]['cache_size'])

    # First chunk index owned by this worker
    zero_chunk = self.coords[0] * chunks_per_worker
    chunk_id   = zero_chunk

    # Advance past any chunk files already written to the cache
    file = f"{self.dsinfo['zarr_cache']}/{var}/{chunk_id}." + trailing_zeros
    while os.path.isfile(file):
        logger.debug(f"Located {file}")
        chunk_id += 1
        file = f"{self.dsinfo['zarr_cache']}/{var}/{chunk_id}." + trailing_zeros

    if chunk_id > zero_chunk:
        logger.info(f"Resuming from chunk ID: {chunk_id}")
    else:
        logger.info(f"Starting from chunk ID: {chunk_id}")
    return chunk_id - zero_chunk
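
The progress check works directly on the zarr v2 chunk files in the cache: for an N-dimensional variable, the chunk at index k along the primary (first) dimension lives at <zarr_cache>/<var>/k.0...0, so scanning for the first missing file gives the number of chunks this worker has already written. A worked example with assumed sizes:

# Worked example with assumed sizes (illustration only).
worker_size, cache_size = 100, 10
chunks_per_worker = worker_size // cache_size     # 10 chunks per worker

zero_chunk = 3 * chunks_per_worker                # worker coords[0] == 3 owns 30..39
existing = {30, 31, 32, 33, 34}                   # pretend 30.0.0 .. 34.0.0 exist

chunk_id = zero_chunk
while chunk_id in existing:                       # stand-in for os.path.isfile
    chunk_id += 1

print(chunk_id - zero_chunk)                      # 5 -> resume five chunks in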

Functions