Feature/combine zarr to netcdf #3

Open
wants to merge 6 commits into base: main
5 changes: 5 additions & 0 deletions .gitignore
@@ -121,3 +121,8 @@ raw-data/
calibration/*.ecs
calibration/*.xlsx

# .zarr files Reka wanted exported as .nc
rekaexport/*
netcdf_output/*


177 changes: 177 additions & 0 deletions combine-zarr-convert-netcdf.py
@@ -0,0 +1,177 @@
import os
import logging
import xarray as xr
import numpy as np
from pathlib import Path
from collections import defaultdict

from saildrone.store import PostgresDB, FileSegmentService
from saildrone.process.plot import plot_sv_data
from saildrone.process.concat import merge_location_data
from echopype.commongrid import compute_NASC, compute_MVBS

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define directories
ZARR_ROOT = Path("./rekaexport")
OUTPUT_DIR = Path("./netcdf_output")
PLOT_OUTPUT_DIR = Path("./echogram_plots")
OUTPUT_DIR.mkdir(exist_ok=True)
PLOT_OUTPUT_DIR.mkdir(exist_ok=True)

# Function to find valid file IDs
def find_valid_ids(base_directory):
    """Finds valid dataset IDs in subdirectories and retrieves metadata."""
    grouped_files = defaultdict(lambda: {"normal": [], "denoised": [], "metadata": []})

    with PostgresDB() as db:
        file_service = FileSegmentService(db)

        for survey_dir in base_directory.iterdir():
            if survey_dir.is_dir():
                for file_id_dir in survey_dir.iterdir():
                    if file_id_dir.is_dir():
                        file_id = file_id_dir.name
                        normal_zarr = file_id_dir / f"{file_id}.zarr"
                        denoised_zarr = file_id_dir / f"{file_id}_denoised.zarr"

                        if normal_zarr.exists() and denoised_zarr.exists():
                            metadata = file_service.get_file_metadata(file_id)

                            file_freqs = metadata.get("file_freqs", "unknown")
                            category = (
                                "short_pulse" if file_freqs == "38000.0,200000.0"
                                else "long_pulse" if file_freqs == "38000.0"
                                else "exported_ds"
                            )

                            print(f"Found valid file ID: {file_id} | file_freqs: {file_freqs} | category: {category}")
                            grouped_files[category]["normal"].append(normal_zarr)
                            grouped_files[category]["denoised"].append(denoised_zarr)
                            grouped_files[category]["metadata"].append(metadata)

    return grouped_files

# Function to combine datasets per frequency
def combine_zarr_files(zarr_files, metadata):
    """Loads and combines multiple Zarr files while ensuring consistent dimension alignment."""
    datasets = [xr.open_zarr(f) for f in zarr_files]

    # Merge location data for each dataset
    for i, ds in enumerate(datasets):
        location_data = metadata[i].get("location_data", {})
        # datasets[i] = merge_location_data(ds, location_data)

        if "time" in ds.dims:
            ds = ds.drop_dims("time", errors="ignore")

        if "filenames" in ds.dims:
            ds = ds.drop_dims("filenames", errors="ignore")

        datasets[i] = ds

    sorted_datasets = sorted(datasets, key=lambda ds: ds["ping_time"].min().values)

    # sorted_datasets = [
    #     ds.rename({"source_filenames": f"source_filenames_{i}"})
    #     for i, ds in enumerate(sorted_datasets)
    # ]

    # Merge the sorted datasets into a single dataset
    concatenated_ds = xr.merge(sorted_datasets)

    if "ping_time" in concatenated_ds.dims and "time" in concatenated_ds.dims:
        concatenated_ds = concatenated_ds.drop_vars("time", errors="ignore")

    return concatenated_ds

# Function to save as NetCDF
def save_to_netcdf(dataset, output_path):
    """Saves dataset as NetCDF."""
    logging.info(f"Saving NetCDF file to {output_path}")
    dataset.to_netcdf(output_path, format="NETCDF4")

# Function to plot NetCDF file
def plot_netcdf(netcdf_path):
    """Loads a NetCDF file and plots the echogram."""
    ds = xr.open_dataset(netcdf_path)
    file_base_name = netcdf_path.stem

    print(f"Plotting echograms for {file_base_name}\n{ds}")

    plot_sv_data(ds, file_base_name=file_base_name, output_path=PLOT_OUTPUT_DIR, depth_var="depth")
    logging.info(f"Plots saved in: {PLOT_OUTPUT_DIR}")

# Main processing loop
if __name__ == "__main__":
    grouped_files = find_valid_ids(ZARR_ROOT)

    for category, data in grouped_files.items():

        print(f"\n\n---- Processing category: {category}----\n\n")
        if data["normal"]:
            normal_ds = combine_zarr_files(data["normal"], data["metadata"])
            output_file = OUTPUT_DIR / f"{category}.nc"
            output_file_mvbs = OUTPUT_DIR / f"{category}_MVBS.nc"
            output_file_nasc = OUTPUT_DIR / f"{category}_NASC.nc"

            # netcdf
            save_to_netcdf(normal_ds, output_file)
            plot_netcdf(output_file)

            # MVBS
            # ds_MVBS = compute_MVBS(
            #     normal_ds,
            #     range_var="depth",
            #     range_bin='1m',  # in meters
            #     ping_time_bin='5s',  # in seconds
            # )
            # save_to_netcdf(ds_MVBS, output_file_mvbs)
            # plot_netcdf(output_file_mvbs)

            # NASC
            # ds_NASC = compute_NASC(
            #     normal_ds,
            #     range_bin="10m",
            #     dist_bin="0.5nmi"
            # )
            # # Log-transform the NASC values for plotting
            # ds_NASC["NASC_log"] = 10 * np.log10(ds_NASC["NASC"])
            # ds_NASC["NASC_log"].attrs = {
            #     "long_name": "Log of NASC",
            #     "units": "m2 nmi-2"
            # }
            # save_to_netcdf(ds_NASC, output_file_nasc)
            # plot_netcdf(output_file_nasc)

        if data["denoised"]:
            denoised_ds = combine_zarr_files(data["denoised"], data["metadata"])
            output_file_denoised = OUTPUT_DIR / f"{category}_denoised.nc"
            output_file_denoised_mvbs = OUTPUT_DIR / f"{category}_denoised_MVBS.nc"
            output_file_denoised_nasc = OUTPUT_DIR / f"{category}_denoised_NASC.nc"

            # MVBS
            # ds_MVBS = compute_MVBS(
            #     denoised_ds,
            #     range_var="depth",
            #     range_bin='1m',  # in meters
            #     ping_time_bin='5s',  # in seconds
            # )
            # save_to_netcdf(ds_MVBS, output_file_denoised_mvbs)
            # plot_netcdf(output_file_denoised_mvbs)

            # NASC
            # ds_NASC = compute_NASC(
            #     denoised_ds,
            #     range_bin="10m",
            #     dist_bin="0.5nmi"
            # )
            # # Log-transform the NASC values for plotting
            # ds_NASC["NASC_log"] = 10 * np.log10(ds_NASC["NASC"])
            # ds_NASC["NASC_log"].attrs = {
            #     "long_name": "Log of NASC",
            #     "units": "m2 nmi-2"
            # }
            # save_to_netcdf(ds_NASC, output_file_denoised_nasc)
            # plot_netcdf(output_file_denoised_nasc)

            save_to_netcdf(denoised_ds, output_file_denoised)
            plot_netcdf(output_file_denoised)

    logging.info("Processing complete.")
86 changes: 79 additions & 7 deletions saildrone/process/concat.py
@@ -8,19 +8,91 @@


def merge_location_data(dataset: xr.Dataset, location_data):
    """Merge location data into the dataset while ensuring it's a variable, not just an attribute."""
    # Convert location_data to a Pandas DataFrame
    location_df = pd.DataFrame(location_data)

    # Convert timestamp strings to datetime objects
    location_df['dt'] = pd.to_datetime(location_df['dt'])

    # Create xarray variables from the location data
    dataset['latitude'] = xr.DataArray(location_df['lat'].values, dims='time',
                                       coords={'time': location_df['dt'].values})
    dataset['longitude'] = xr.DataArray(location_df['lon'].values, dims='time',
                                        coords={'time': location_df['dt'].values})
    dataset['speed_knots'] = xr.DataArray(location_df['knt'].values, dims='time',
                                          coords={'time': location_df['dt'].values})

    # Determine which time dimension to use
    time_dim = "ping_time" if "ping_time" in dataset.dims else "time" if "time" in dataset.dims else None
    if not time_dim:
        return dataset  # Return without merging if no time dimension exists

    # Interpolate location data to match dataset time
    target_times = dataset[time_dim].values

    lat_interp = np.interp(
        np.array(pd.to_datetime(target_times).astype(int)),
        np.array(location_df["dt"].astype(int)),
        location_df["lat"]
    )

    lon_interp = np.interp(
        np.array(pd.to_datetime(target_times).astype(int)),
        np.array(location_df["dt"].astype(int)),
        location_df["lon"]
    )

    speed_interp = np.interp(
        np.array(pd.to_datetime(target_times).astype(int)),
        np.array(location_df["dt"].astype(int)),
        location_df["knt"]
    )

    # Ensure latitude is stored as a variable, not just an attribute
    dataset['latitude'] = xr.DataArray(lat_interp, dims=time_dim, coords={time_dim: target_times})
    dataset['longitude'] = xr.DataArray(lon_interp, dims=time_dim, coords={time_dim: target_times})
    dataset['speed_knots'] = xr.DataArray(speed_interp, dims=time_dim, coords={time_dim: target_times})

    # Debugging: Print dataset variables after merging

    return dataset


def xmerge_location_data(dataset: xr.Dataset, location_data):
    """Merge location data into the dataset while ensuring time alignment using interpolation."""
    # Convert location_data to a Pandas DataFrame
    location_df = pd.DataFrame(location_data)

    # Convert timestamp strings to datetime objects
    location_df['dt'] = pd.to_datetime(location_df['dt'])

    if "ping_time" in dataset.dims:
        # Interpolate location data to match 'ping_time'
        target_times = dataset["ping_time"].values

        lat_interp = np.interp(
            np.array(pd.to_datetime(target_times).astype(int)),
            np.array(location_df["dt"].astype(int)),
            location_df["lat"]
        )

        lon_interp = np.interp(
            np.array(pd.to_datetime(target_times).astype(int)),
            np.array(location_df["dt"].astype(int)),
            location_df["lon"]
        )

        speed_interp = np.interp(
            np.array(pd.to_datetime(target_times).astype(int)),
            np.array(location_df["dt"].astype(int)),
            location_df["knt"]
        )

        dataset['latitude'] = xr.DataArray(lat_interp, dims="ping_time", coords={"ping_time": target_times})
        dataset['longitude'] = xr.DataArray(lon_interp, dims="ping_time", coords={"ping_time": target_times})
        dataset['speed_knots'] = xr.DataArray(speed_interp, dims="ping_time", coords={"ping_time": target_times})

    else:
        # Default behavior: Assign location data based on its own timestamps
        dataset['latitude'] = xr.DataArray(location_df['lat'].values, dims='time',
                                           coords={'time': location_df['dt'].values})
        dataset['longitude'] = xr.DataArray(location_df['lon'].values, dims='time',
                                            coords={'time': location_df['dt'].values})
        dataset['speed_knots'] = xr.DataArray(location_df['knt'].values, dims='time',
                                              coords={'time': location_df['dt'].values})

    return dataset

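For reference, here is a self-contained sketch of the interpolation pattern the new merge_location_data relies on: both time axes are cast to int64 nanoseconds so np.interp can align sparse GPS fixes with the dataset's ping_time axis. The timestamps and coordinates below are made up purely for illustration.

# Toy illustration of the np.interp-on-nanoseconds alignment (not part of the PR)
import numpy as np
import pandas as pd
import xarray as xr

# interpolation targets: echosounder ping times every 30 s
ping_time = pd.date_range("2024-06-01 00:00:00", periods=5, freq="30s")

# sparse GPS fixes with their own timestamps, as location_data would supply
location_df = pd.DataFrame({
    "dt": pd.to_datetime(["2024-06-01 00:00:00", "2024-06-01 00:02:00"]),
    "lat": [57.00, 57.10],
})

# cast both time axes to int64 nanoseconds so np.interp sees plain numbers
lat_interp = np.interp(
    ping_time.astype("int64"),
    location_df["dt"].astype("int64"),
    location_df["lat"],
)

ds = xr.Dataset(coords={"ping_time": ping_time})
ds["latitude"] = xr.DataArray(lat_interp, dims="ping_time")
print(ds["latitude"].values)  # ramps linearly from 57.00 to 57.10
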
38 changes: 35 additions & 3 deletions saildrone/store/filesegment_service.py
@@ -35,6 +35,38 @@ def is_file_processed(self, file_name: str) -> bool:
        self.db.cursor.execute(f'SELECT id FROM {self.table_name} WHERE file_name=%s AND processed=TRUE', (file_name,))
        return self.db.cursor.fetchone() is not None

    def get_file_metadata(self, file_name: str):
        """
        Get metadata for a file from the database.

        Parameters
        ----------
        file_name : str
            The name of the file to look up.

        Returns
        -------
        dict or None
            A dictionary of metadata for the file, or None if no record exists.
        """
        self.db.cursor.execute(f'SELECT id, size, converted, processed, location, file_name, location_data, file_freqs, file_start_time, file_end_time FROM {self.table_name} WHERE file_name=%s', (file_name,))
        row = self.db.cursor.fetchone()

        if row:
            return {
                'id': row[0],
                'size': row[1],
                'converted': row[2],
                'processed': row[3],
                'location': row[4],
                'file_name': row[5],
                'location_data': row[6],
                'file_freqs': row[7],
                'file_start_time': row[8],
                'file_end_time': row[9]
            }

        return None

    def get_file_info(self, file_name: str):
        """
        Get information about a file from the database.

@@ -195,9 +227,9 @@ def insert_file_record(
"""
self.db.cursor.execute('''
INSERT INTO files (
file_name, size, location, processed, converted, last_modified, file_npings, file_nsamples, file_start_time,
file_end_time, file_freqs, file_start_depth, file_end_depth, file_start_lat, file_start_lon,
file_end_lat, file_end_lon, echogram_files, failed, error_details, location_data, processing_time_ms,
file_name, size, location, processed, converted, last_modified, file_npings, file_nsamples, file_start_time,
file_end_time, file_freqs, file_start_depth, file_end_depth, file_start_lat, file_start_lon,
file_end_lat, file_end_lon, echogram_files, failed, error_details, location_data, processing_time_ms,
survey_db_id, downloaded
) VALUES (%s, %s, %s, FALSE, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) RETURNING id
''', (file_name, size, location, converted, last_modified, file_npings, file_nsamples, file_start_time,
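
Finally, a short sketch of how the new get_file_metadata lookup is meant to be consumed, mirroring find_valid_ids in combine-zarr-convert-netcdf.py; the file name below is a placeholder, not a real record.

# Hypothetical usage of FileSegmentService.get_file_metadata (file ID is illustrative)
from saildrone.store import PostgresDB, FileSegmentService

with PostgresDB() as db:
    service = FileSegmentService(db)
    meta = service.get_file_metadata("SD_example_file_id")

    if meta is None:
        print("no database record for this file")
    else:
        # file_freqs drives the short_pulse / long_pulse grouping above,
        # location_data is what merge_location_data interpolates from
        print(meta["file_freqs"], meta["file_start_time"], meta["file_end_time"])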