Skip to content

Commit 86bb64e

Browse files
authored
Merge pull request #792 from KaplanOpenSource/issue641
refactor(postprocess): Improve pvOpenFOAMBase.writeCase (9 helpers)
2 parents b265090 + a35ded9 commit 86bb64e

1 file changed

Lines changed: 133 additions & 74 deletions

File tree

hera/simulations/openFoam/postProcess/pvOpenFOAMBase.py

Lines changed: 133 additions & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -383,134 +383,193 @@ def writeCase(self, filtersDict, regularMesh, timeList=None, fieldnames=None, ts
383383
"""
384384
Write VTK filter results to parquet (unstructured) or zarr (regular) files.
385385
386-
Processes timesteps in blocks to control memory, writes temporary files,
387-
then repartitions and merges into the final output. Supports append mode
388-
(merging with existing data) and overwrite mode.
386+
Algorithm overview:
387+
1. Prepare the filesystem (clean old outputs, create directories).
388+
2. Resolve which timesteps to process.
389+
3. Stream timesteps in fixed-size blocks to temporary files (controls memory).
390+
4. Merge all temporary files into the final output per filter.
389391
390392
Parameters
391393
----------
392394
filtersDict : dict
393-
Mapping of filter name output file path.
395+
Mapping of filter name -> output file path.
394396
regularMesh : bool
395397
If True, write zarr (xarray). If False, write parquet (dask DataFrame).
396398
timeList : list or None
397-
Timesteps to process. None = all available.
399+
Timesteps to process. None = all available from the reader.
398400
fieldnames : list, optional
399401
VTK field names to extract.
400402
tsBlockNum : int
401-
Number of timesteps to accumulate before writing a temporary block.
403+
Number of timesteps to accumulate before flushing a temporary block.
402404
overwrite : bool
403405
If True, remove existing output files before writing.
404406
latestTimestamp : bool
405-
If True, only process the latest timestep.
407+
If True, only process the latest available timestep.
406408
"""
407409
logger = get_classMethod_logger(self, "writeNonRegularCase")
410+
logger.info(f"Starting writing to parquet filters {','.join(filtersDict.keys())}")
411+
412+
# Choose file extension based on mesh type: zarr for regular, parquet otherwise.
408413
slice_filext = "zarr" if regularMesh else "parquet"
414+
# In non-overwrite mode, new data is appended to any existing output.
409415
append = not overwrite
410416

411-
# Step 1: Clean old output files if overwriting.
417+
# Step 1: Prepare the filesystem for writing.
412418
if overwrite:
413419
self._removeOldOutputs(filtersDict)
414-
415-
# Step 2: Ensure output directories exist.
416420
self._ensureOutputDirs(filtersDict)
417421

418-
# Step 3: Resolve the timestep list.
422+
# Step 2: Determine which timesteps to read from the simulation.
423+
readTimesList = self._resolveTimeList(timeList, latestTimestamp)
424+
425+
# Step 3: Stream timesteps in blocks, writing each full block to a temp file.
426+
self._writeTimeStepBlocks(filtersDict, readTimesList, fieldnames,
427+
regularMesh, slice_filext, tsBlockNum)
428+
429+
# Step 4: For each filter, merge its temp files into a single final output.
430+
logger.info("Repartitioning to 100MB per partition")
431+
for filterName, outputFile in filtersDict.items():
432+
tmpFiles = self._collectTmpFiles(filterName, outputFile, slice_filext)
433+
self._mergeToFinalOutput(outputFile, tmpFiles, regularMesh, append)
434+
self._atomicReplace(outputFile)
435+
self._cleanupTmpFiles(tmpFiles)
436+
437+
# ------------------------------------------------------------------
438+
# Private helpers for writeCase
439+
# ------------------------------------------------------------------
440+
441+
def _removeOldOutputs(self, filtersDict):
442+
"""Remove existing output files or directories before overwriting."""
443+
logger = get_classMethod_logger(self, "_removeOldOutputs")
444+
logger.info("Removing the old results")
445+
for filterName, outputPath in filtersDict.items():
446+
logger.debug(f"The data for {filterName} : {outputPath}")
447+
if os.path.isfile(outputPath):
448+
logger.debug(f"\tParquet file {outputPath} is a file. Removing it")
449+
os.remove(outputPath)
450+
elif os.path.isdir(outputPath):
451+
logger.debug(f"\tParquet file {outputPath} is a directory. Removing the tree")
452+
shutil.rmtree(outputPath)
453+
454+
def _ensureOutputDirs(self, filtersDict):
455+
"""Create output directories for each filter if they do not already exist."""
456+
logger = get_classMethod_logger(self, "_ensureOutputDirs")
457+
logger.info("Making sure that the output directories exist")
458+
for filterName, outputFile in filtersDict.items():
459+
outputPath = os.path.dirname(outputFile)
460+
logger.debug(f"{filterName} for directory {outputPath}")
461+
if not os.path.isdir(outputPath):
462+
logger.debug(f"\t Does not exist. Creating {outputPath}")
463+
os.makedirs(outputPath)
464+
465+
def _resolveTimeList(self, timeList, latestTimestamp):
466+
"""Determine which timesteps to process from the reader or caller input.
467+
468+
Returns timeList as given, or the full reader timestep list when timeList is None;
469+
when latestTimestamp is True and the list is non-empty, only its last entry is kept.
470+
"""
471+
logger = get_classMethod_logger(self, "_resolveTimeList")
419472
readTimesList = self.reader.TimestepValues if timeList is None else timeList
473+
logger.info(f"Getting timelist {readTimesList}")
420474
if latestTimestamp and len(readTimesList) != 0:
421475
readTimesList = [readTimesList[-1]]
476+
return readTimesList
422477

423-
# Step 4: Process timesteps in blocks of tsBlockNum.
424-
# Each block is written to a temporary file to limit memory usage.
478+
def _writeTimeStepBlocks(self, filtersDict, readTimesList, fieldnames,
479+
regularMesh, slice_filext, tsBlockNum):
480+
"""Stream timesteps from the VTK pipeline, flushing to temp files in blocks.
481+
482+
Accumulates up to tsBlockNum timesteps in memory, then writes them
483+
to a numbered temporary file via writeList. Any leftover timesteps
484+
that do not fill a complete block are flushed at the end.
485+
"""
486+
logger = get_classMethod_logger(self, "_writeTimeStepBlocks")
425487
blockID = 0
426488
tempList = []
489+
427490
for filtersData in tqdm.tqdm(self.readTimeSteps(
428491
datasourcenamedict=filtersDict, timelist=readTimesList,
429492
fieldnames=fieldnames, regularMesh=regularMesh)):
430493
tempList.append(filtersData)
494+
logger.debug(f"Current dataFrames in memory {len(tempList)}")
495+
# Flush the block to disk once it reaches the target size.
431496
if len(tempList) == tsBlockNum:
432497
self.writeList(tempList, blockID, filtersDict, regularMesh, slice_filext)
433498
tempList = []
434499
blockID += 1
435-
# Write any remaining timesteps.
500+
501+
# Flush any remaining timesteps that did not fill a complete block.
436502
if len(tempList) > 0:
437503
self.writeList(tempList, blockID, filtersDict, regularMesh, slice_filext)
438504

439-
# Step 5: Merge temporary block files into final output.
440-
# Repartitions to ~100MB chunks and optionally appends to existing data.
441-
for filterName, outputFile in filtersDict.items():
442-
self._mergeTemporaryBlocks(
443-
filterName, outputFile, regularMesh, slice_filext, append
444-
)
445-
446-
def _removeOldOutputs(self, filtersDict):
447-
"""Remove existing output files/dirs before overwriting."""
448-
logger = get_classMethod_logger(self, "_removeOldOutputs")
449-
for filterName, outputPath in filtersDict.items():
450-
if os.path.isfile(outputPath):
451-
os.remove(outputPath)
452-
elif os.path.isdir(outputPath):
453-
shutil.rmtree(outputPath)
454-
455-
@staticmethod
456-
def _ensureOutputDirs(filtersDict):
457-
"""Create output directories if they don't exist."""
458-
for filterName, outputFile in filtersDict.items():
459-
outputDir = os.path.dirname(outputFile)
460-
if outputDir and not os.path.isdir(outputDir):
461-
os.makedirs(outputDir)
505+
def _collectTmpFiles(self, filterName, outputFile, slice_filext):
506+
"""Glob all numbered temporary block files produced for a given filter."""
507+
outputPath = os.path.dirname(outputFile)
508+
tmpPattern = f"tmp_{filterName.replace('.', '-')}_*.{slice_filext}"
509+
return glob.glob(os.path.join(outputPath, tmpPattern))
462510

463-
def _mergeTemporaryBlocks(self, filterName, outputFile, regularMesh, fileExt, append):
464-
"""Merge temporary block files into final output, repartition, and clean up.
511+
def _mergeToFinalOutput(self, outputFile, tmpFiles, regularMesh, append):
512+
"""Merge temporary block files into a single '.final' staging file.
465513
466-
For regular meshes: uses xarray + zarr with optional time-concatenation.
467-
For unstructured meshes: uses dask + parquet with repartitioning.
468-
Writes to a .final temp file then atomically renames.
514+
For regular meshes (zarr): opens all blocks as a lazy multi-file dataset,
515+
optionally concatenates with previously saved data, and writes to zarr.
516+
For unstructured meshes (parquet): concatenates all blocks with dask,
517+
repartitions to ~100 MB chunks indexed by time, and writes to parquet.
469518
"""
470-
logger = get_classMethod_logger(self, "_mergeTemporaryBlocks")
471-
outputPath = os.path.dirname(outputFile)
472-
# Find all temporary block files for this filter.
473-
tmpPattern = f"tmp_{filterName.replace('.', '-')}_*.{fileExt}"
474-
outputFileList = glob.glob(os.path.join(outputPath, tmpPattern))
519+
logger = get_classMethod_logger(self, "_mergeToFinalOutput")
520+
logger.debug(f"Saving all data to {outputFile}: {tmpFiles}")
475521

476-
logger.info(f"Merging {len(outputFileList)} blocks → {outputFile}")
477522
with ProgressBar():
478523
if regularMesh:
479-
# Zarr path: open all blocks as a multi-file dataset, optionally
480-
# append existing data, write to .final then rename.
481-
lazy_ds = xarray.open_mfdataset(outputFileList, chunks='auto', engine="zarr")
482-
if append and os.path.exists(outputFile):
483-
old_data = xarray.open_mfdataset(outputFile, chunks='auto', engine="zarr")
484-
lazy_ds = xarray.concat([lazy_ds, old_data], dim="time").sortby("time")
485-
try:
486-
lazy_ds.to_zarr(f"{outputFile}.final", mode='w')
487-
except NotImplementedError:
488-
# Workaround: some xarray versions need explicit rechunking.
489-
lazy_ds.chunk("auto").to_zarr(f"{outputFile}.final", mode='w')
524+
self._mergeZarr(outputFile, tmpFiles, append)
490525
else:
491-
# Parquet path: concat all blocks, optionally append existing,
492-
# repartition to 100MB, index by time, write to .final.
493-
newDataList = [dd.read_parquet(f) for f in outputFileList]
494-
if append and os.path.exists(outputFile):
495-
newDataList.append(dd.read_parquet(outputFile))
496-
dd.concat(newDataList).repartition(partition_size="100MB") \
497-
.reset_index().set_index("time") \
498-
.to_parquet(f"{outputFile}.final")
499-
500-
# Atomic replace: remove old → rename .final → output.
526+
self._mergeParquet(outputFile, tmpFiles, append)
527+
528+
def _mergeZarr(self, outputFile, tmpFiles, append):
529+
"""Concatenate temporary zarr blocks, optionally appending old data."""
530+
lazy_ds = xarray.open_mfdataset(tmpFiles, chunks='auto', engine="zarr")
531+
# If appending, include previously saved data so nothing is lost.
532+
if append and os.path.exists(outputFile):
533+
old_data = xarray.open_mfdataset(outputFile, chunks='auto', engine="zarr")
534+
lazy_ds = xarray.concat([lazy_ds, old_data], dim="time").sortby("time")
535+
try:
536+
lazy_ds.to_zarr(f"{outputFile}.final", mode='w')
537+
except NotImplementedError:
538+
# Sometimes the direct write works and sometimes only the rechunked write does; it is not yet clear when.
539+
lazy_ds.chunk("auto").to_zarr(f"{outputFile}.final", mode='w')
540+
541+
def _mergeParquet(self, outputFile, tmpFiles, append):
542+
"""Concatenate temporary parquet blocks, repartition to ~100 MB, index by time."""
543+
newDataList = [dd.read_parquet(fileName) for fileName in tmpFiles]
544+
# If appending, include the previously saved parquet data.
545+
if append and os.path.exists(outputFile):
546+
newDataList.append(dd.read_parquet(outputFile))
547+
dd.concat(newDataList).repartition(partition_size="100MB") \
548+
.reset_index() \
549+
.set_index("time") \
550+
.to_parquet(f"{outputFile}.final")
551+
552+
def _atomicReplace(self, outputFile):
553+
"""Swap the '.final' staging file into the target output path.
554+
555+
Removes the old output (file or directory) first, then renames the staging file; despite the method name, these are two separate steps, not a single atomic operation.
556+
"""
501557
if os.path.exists(outputFile):
502558
if os.path.isfile(outputFile):
503559
os.remove(outputFile)
504560
else:
505561
shutil.rmtree(outputFile)
506562
os.rename(f"{outputFile}.final", outputFile)
507563

508-
# Clean up temporary block files.
509-
for tmpFile in outputFileList:
510-
if os.path.isfile(tmpFile):
511-
os.remove(tmpFile)
564+
def _cleanupTmpFiles(self, tmpFiles):
565+
"""Remove all temporary block files after a successful merge."""
566+
logger = get_classMethod_logger(self, "_cleanupTmpFiles")
567+
logger.debug("Removing the old tmp files. ")
568+
for fileTodelete in tmpFiles:
569+
if os.path.isfile(fileTodelete):
570+
os.remove(fileTodelete)
512571
else:
513-
shutil.rmtree(tmpFile)
572+
shutil.rmtree(fileTodelete)
514573

515574
def writeList(self,theList,blockID,filtersDict,regularMesh,fileExt):
516575
"""Write a list of time step data blocks to temporary files."""

0 commit comments

Comments
 (0)