Skip to content

Commit e587d82

Browse files
authored
transpose dimensions (#4)
1 parent e4b550f commit e587d82

File tree

3 files changed

+251
-1
lines changed

3 files changed

+251
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ Attributes:
148148
- [x] Write to STRDS
149149
- [x] Write to 3D raster
150150
- [x] Write to STR3DS
151-
- [ ] Transpose if dimensions are not in the expected order
151+
- [x] Transpose if dimensions are not in the expected order
152152
- [ ] Support time units for relative time
153153
- [ ] Support `end_time`
154154
- [ ] Accept writing into a specific mapset (GRASS 8.5)

src/xarray_grass/to_grass.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,12 @@ def _datarray_to_grass(
304304
# TODO: reshape to match userGRASS expected dims order
305305
try:
306306
if is_raster:
307+
data = self.transpose(data, dims, arr_type="raster")
307308
self.grass_interface.write_raster_map(data, data.name)
308309
elif is_strds:
309310
self._write_stds(data, dims)
310311
elif is_raster_3d:
312+
data = self.transpose(data, dims, arr_type="raster3d")
311313
self.grass_interface.write_raster3d_map(data, data.name)
312314
elif is_str3ds:
313315
self._write_stds(data, dims)
@@ -320,6 +322,19 @@ def _datarray_to_grass(
320322
# Restore the original region
321323
self.grass_interface.set_region(current_region)
322324

325+
def transpose(
326+
self, da: xr.DataArray, dims, arr_type: str = "raster"
327+
) -> xr.DataArray:
328+
"""Force dimension order to conform with grass expectation."""
329+
if "raster" == arr_type:
330+
return da.transpose(dims["y"], dims["x"])
331+
elif "raster3d" == arr_type:
332+
return da.transpose(dims["z"], dims["y_3d"], dims["x_3d"])
333+
else:
334+
raise ValueError(
335+
f"Unknown array type: {arr_type}. Must be 'raster' or 'raster3d'."
336+
)
337+
323338
def _write_stds(self, data: xr.DataArray, dims: Mapping):
324339
# 1. Determine the temporal coordinate and type
325340
time_coord = data[dims["start_time"]]
@@ -337,14 +352,17 @@ def _write_stds(self, data: xr.DataArray, dims: Mapping):
337352
# 2.5 determine if 2D or 3D
338353
is_3d = False
339354
stds_type = "strds"
355+
arr_type = "raster"
340356
if len(data.isel({dims["start_time"]: 0}).dims) == 3:
341357
is_3d = True
342358
stds_type = "str3ds"
359+
arr_type = "raster3d"
343360

344361
# 3. Loop through the time dim:
345362
map_list = []
346363
for index, time in enumerate(time_coord):
347364
darray = data.sel({dims["start_time"]: time})
365+
darray = self.transpose(darray, dims, arr_type=arr_type)
348366
nd_array = darray.values
349367
# 3.1 Write each map individually
350368
raster_name = f"{data.name}_{temporal_type}_{index}"

tests/test_tograss.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,3 +779,235 @@ def test_dims_mapping(
779779
)
780780
assert int(info["rows"]) == img_height
781781
assert int(info["cols"]) == img_width
782+
783+
def test_dimension_transposition(
784+
self,
785+
temp_gisdb,
786+
grass_i: GrassInterface,
787+
):
788+
"""Test that to_grass() correctly transposes dimensions to GRASS format.
789+
790+
Creates DataArrays with standard dimension names but in non-standard order,
791+
and verifies they are correctly transposed when written to GRASS.
792+
"""
793+
session_crs_wkt = grass_i.get_crs_wkt_str()
794+
target_mapset_name = temp_gisdb.mapset
795+
mapset_path_obj = (
796+
Path(temp_gisdb.gisdb) / temp_gisdb.project / target_mapset_name
797+
)
798+
mapset_arg = str(mapset_path_obj)
799+
800+
# Define expected dimensions for verification
801+
height_2d, width_2d = 5, 7
802+
depth_3d, height_3d, width_3d = 3, 4, 6
803+
num_times_strds = 2
804+
height_strds, width_strds = 6, 5
805+
num_times_str3ds = 2
806+
depth_str3ds, height_str3ds, width_str3ds = 3, 5, 4
807+
808+
# 1. Test 2D Raster: Create with standard dims but transpose before writing
809+
raster2d_name = "test_transpose_2d"
810+
da_2d = create_sample_dataarray(
811+
dims_spec={
812+
"y": np.arange(height_2d, dtype=float),
813+
"x": np.arange(width_2d, dtype=float),
814+
},
815+
shape=(height_2d, width_2d),
816+
crs_wkt=session_crs_wkt,
817+
name=raster2d_name,
818+
fill_value_generator=lambda s: np.arange(s[0] * s[1])
819+
.reshape(s)
820+
.astype(float),
821+
)
822+
# Transpose to non-standard order (x, y instead of y, x)
823+
da_2d = da_2d.transpose("x", "y")
824+
assert da_2d.dims == ("x", "y"), f"Expected dims ('x', 'y'), got {da_2d.dims}"
825+
assert da_2d.shape == (width_2d, height_2d)
826+
827+
# 2. Test 3D Raster: Create with z,y,x then transpose to x,z,y
828+
raster3d_name = "test_transpose_3d"
829+
res3 = 1000
830+
da_3d = create_sample_dataarray(
831+
dims_spec={
832+
"z": np.arange(depth_3d, dtype=float),
833+
"y": np.linspace(220000, 220000 + (height_3d - 1) * res3, height_3d),
834+
"x": np.linspace(630000, 630000 + (width_3d - 1) * res3, width_3d),
835+
},
836+
shape=(depth_3d, height_3d, width_3d),
837+
crs_wkt=session_crs_wkt,
838+
name=raster3d_name,
839+
fill_value_generator=lambda s: np.arange(s[0] * s[1] * s[2])
840+
.reshape(s)
841+
.astype(float),
842+
)
843+
# Transpose to non-standard order (x_3d, z, y_3d)
844+
da_3d = da_3d.transpose("x_3d", "z", "y_3d")
845+
assert da_3d.dims == ("x_3d", "z", "y_3d")
846+
assert da_3d.shape == (width_3d, depth_3d, height_3d)
847+
848+
# 3. Test STRDS: Create with start_time,y,x then transpose to x,y,start_time
849+
strds_name = "test_transpose_strds"
850+
da_strds = create_sample_dataarray(
851+
dims_spec={
852+
"start_time": np.arange(1, num_times_strds + 1),
853+
"y": np.arange(height_strds, dtype=float),
854+
"x": np.arange(width_strds, dtype=float),
855+
},
856+
shape=(num_times_strds, height_strds, width_strds),
857+
crs_wkt=session_crs_wkt,
858+
name=strds_name,
859+
time_dim_type="relative",
860+
fill_value_generator=lambda s: np.arange(s[0] * s[1] * s[2])
861+
.reshape(s)
862+
.astype(float),
863+
)
864+
# Transpose to non-standard order (x, y, start_time)
865+
da_strds = da_strds.transpose("x", "y", "start_time")
866+
assert da_strds.dims == ("x", "y", "start_time")
867+
assert da_strds.shape == (width_strds, height_strds, num_times_strds)
868+
869+
# 4. Test STR3DS: Create with time,z,y,x then transpose to y_3d,time,x_3d,z
870+
str3ds_name = "test_transpose_str3ds"
871+
res3_str3ds = 1000
872+
da_str3ds = create_sample_dataarray(
873+
dims_spec={
874+
"time": np.arange(1, num_times_str3ds + 1),
875+
"z": np.arange(depth_str3ds, dtype=float),
876+
"y": np.linspace(
877+
220000, 220000 + (height_str3ds - 1) * res3_str3ds, height_str3ds
878+
),
879+
"x": np.linspace(
880+
630000, 630000 + (width_str3ds - 1) * res3_str3ds, width_str3ds
881+
),
882+
},
883+
shape=(num_times_str3ds, depth_str3ds, height_str3ds, width_str3ds),
884+
crs_wkt=session_crs_wkt,
885+
name=str3ds_name,
886+
time_dim_type="relative",
887+
fill_value_generator=lambda s: np.arange(s[0] * s[1] * s[2] * s[3])
888+
.reshape(s)
889+
.astype(float),
890+
)
891+
# Transpose to non-standard order (y_3d, time, x_3d, z)
892+
da_str3ds = da_str3ds.transpose("y_3d", "time", "x_3d", "z")
893+
assert da_str3ds.dims == ("y_3d", "time", "x_3d", "z")
894+
assert da_str3ds.shape == (
895+
height_str3ds,
896+
num_times_str3ds,
897+
width_str3ds,
898+
depth_str3ds,
899+
)
900+
901+
# Write all DataArrays to GRASS
902+
raster2d_id = grass_i.get_id_from_name(raster2d_name)
903+
raster3d_id = grass_i.get_id_from_name(raster3d_name)
904+
strds_id = grass_i.get_id_from_name(strds_name)
905+
str3ds_id = grass_i.get_id_from_name(str3ds_name)
906+
907+
try:
908+
# Write 2D raster
909+
to_grass(dataset=da_2d, mapset=mapset_arg)
910+
911+
# Write 3D raster
912+
to_grass(dataset=da_3d, mapset=mapset_arg)
913+
914+
# Write STRDS
915+
to_grass(dataset=da_strds, mapset=mapset_arg)
916+
917+
# Write STR3DS
918+
to_grass(
919+
dataset=da_str3ds,
920+
mapset=mapset_arg,
921+
dims={str3ds_name: {"start_time": "time"}},
922+
)
923+
924+
# Verify 2D Raster dimensions
925+
info_2d = gs.parse_command("r.info", map=raster2d_id, flags="g", quiet=True)
926+
assert int(info_2d["rows"]) == height_2d, (
927+
f"2D Raster rows mismatch: expected {height_2d}, got {info_2d['rows']}"
928+
)
929+
assert int(info_2d["cols"]) == width_2d, (
930+
f"2D Raster cols mismatch: expected {width_2d}, got {info_2d['cols']}"
931+
)
932+
933+
# Verify 3D Raster dimensions
934+
info_3d = gs.parse_command(
935+
"r3.info", map=raster3d_id, flags="g", quiet=True
936+
)
937+
assert int(info_3d["depths"]) == depth_3d, (
938+
f"3D Raster depths mismatch: expected {depth_3d}, got {info_3d['depths']}"
939+
)
940+
assert int(info_3d["rows"]) == height_3d, (
941+
f"3D Raster rows mismatch: expected {height_3d}, got {info_3d['rows']}"
942+
)
943+
assert int(info_3d["cols"]) == width_3d, (
944+
f"3D Raster cols mismatch: expected {width_3d}, got {info_3d['cols']}"
945+
)
946+
947+
# Verify STRDS
948+
strds_maps = grass_i.list_maps_in_strds(strds_id)
949+
assert len(strds_maps) == num_times_strds, (
950+
f"STRDS map count mismatch: expected {num_times_strds}, got {len(strds_maps)}"
951+
)
952+
# Check dimensions of first map in STRDS
953+
first_map_id = strds_maps[0].id
954+
info_strds = gs.parse_command(
955+
"r.info", map=first_map_id, flags="g", quiet=True
956+
)
957+
assert int(info_strds["rows"]) == height_strds, (
958+
f"STRDS map rows mismatch: expected {height_strds}, got {info_strds['rows']}"
959+
)
960+
assert int(info_strds["cols"]) == width_strds, (
961+
f"STRDS map cols mismatch: expected {width_strds}, got {info_strds['cols']}"
962+
)
963+
964+
# Verify STR3DS
965+
str3ds_maps = grass_i.list_maps_in_str3ds(str3ds_id)
966+
assert len(str3ds_maps) == num_times_str3ds, (
967+
f"STR3DS map count mismatch: expected {num_times_str3ds}, got {len(str3ds_maps)}"
968+
)
969+
# Check dimensions of first map in STR3DS
970+
first_map_3d_id = str3ds_maps[0].id
971+
info_str3ds = gs.parse_command(
972+
"r3.info", map=first_map_3d_id, flags="g", quiet=True
973+
)
974+
assert int(info_str3ds["depths"]) == depth_str3ds, (
975+
f"STR3DS map depths mismatch: expected {depth_str3ds}, got {info_str3ds['depths']}"
976+
)
977+
assert int(info_str3ds["rows"]) == height_str3ds, (
978+
f"STR3DS map rows mismatch: expected {height_str3ds}, got {info_str3ds['rows']}"
979+
)
980+
assert int(info_str3ds["cols"]) == width_str3ds, (
981+
f"STR3DS map cols mismatch: expected {width_str3ds}, got {info_str3ds['cols']}"
982+
)
983+
984+
finally:
985+
# Cleanup
986+
try:
987+
gs.run_command(
988+
"g.remove", flags="f", type="raster", name=raster2d_id, quiet=True
989+
)
990+
except CalledModuleError:
991+
pass
992+
try:
993+
gs.run_command(
994+
"g.remove",
995+
flags="f",
996+
type="raster_3d",
997+
name=raster3d_id,
998+
quiet=True,
999+
)
1000+
except CalledModuleError:
1001+
pass
1002+
try:
1003+
gs.run_command(
1004+
"t.remove", inputs=strds_id, type="strds", flags="rfd", quiet=True
1005+
)
1006+
except CalledModuleError:
1007+
pass
1008+
try:
1009+
gs.run_command(
1010+
"t.remove", inputs=str3ds_id, type="str3ds", flags="rfd", quiet=True
1011+
)
1012+
except CalledModuleError:
1013+
pass

0 commit comments

Comments
 (0)