-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreal_generator.py
More file actions
63 lines (48 loc) · 1.94 KB
/
real_generator.py
File metadata and controls
63 lines (48 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
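'''
Hard-coded (non-learned) voxel-map generator: instead of generating maps, it
returns real binary occupancy grids built from the training and validation
houses, sampled at random.
'''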
import hydra.utils
import lightning as L
import numpy as np
import torch
import torch.nn.functional as F
import torchmetrics
from torch import Tensor
import os
import dataloader
from voxelcnn.datasets import MinecraftTokenizer
# Absolute directory containing this module; cached map files are written here.
HOME_DIR = os.path.split(os.path.abspath(__file__))[0]
class RealOccupancyGenerator:
    """Serve binary occupancy voxel maps taken from the real dataset.

    ``create_maps_file`` turns every batch of the training and validation
    loaders (from ``dataloader.get_dataloaders``) into a 0/1 map, where an
    entry is ``1`` wherever the batch value is ``> 0``. The maps are saved with
    ``torch.save`` and then loaded into ``self.maps``, from which
    ``get_random_batch`` samples.

    Args:
        config: Project config. ``config.data.air_not_air`` is passed to
            ``MinecraftTokenizer``, and the whole config is passed to the
            tokenizer and to ``dataloader.get_dataloaders``.
        stored_maps_file: Path of a maps tensor previously written by
            ``create_maps_file``. When ``None`` (the default), the maps are
            rebuilt from the loaders and saved to
            ``HOME_DIR/32real_occupancy_maps.pth`` before being loaded.
    """

    def __init__(self, config, stored_maps_file=None):
        self.config = config
        tokenizer = MinecraftTokenizer(config, self.config.data.air_not_air)
        # The loaders are always built and kept as attributes so callers can
        # still reach ``train_ds`` / ``valid_ds`` even when loading a cache.
        self.train_ds, self.valid_ds = dataloader.get_dataloaders(
            config, tokenizer)
        if stored_maps_file is None:
            stored_maps_file = os.path.join(
                HOME_DIR, '32real_occupancy_maps.pth')
            self.create_maps_file(stored_maps_file)
        # Load the path actually requested, so a caller-supplied cache file is
        # honoured instead of the default file always being read.
        self.maps = torch.load(stored_maps_file)

    def create_maps_file(self, file_path):
        """Convert every loader batch to a 0/1 occupancy map and save them.

        Batches from ``self.train_ds`` come first, followed by those from
        ``self.valid_ds``, concatenated along dim 0.

        Args:
            file_path: Destination passed to ``torch.save``.

        Returns:
            The saved ``torch.long`` tensor of occupancy maps.

        Raises:
            RuntimeError: If neither loader yields any batch.
        """
        binary_occupancy_maps = [
            (batch > 0).long()
            for loader in (self.train_ds, self.valid_ds)
            for batch in loader
        ]
        if not binary_occupancy_maps:
            # torch.cat on an empty list fails with an unhelpful message.
            raise RuntimeError(
                "No batches found in train_ds or valid_ds; "
                "cannot build occupancy maps")
        # Expected (N, X, Y, Z), assuming the loaders yield (B, X, Y, Z)
        # voxel batches; TODO confirm against dataloader.get_dataloaders.
        occupancy_tensor = torch.cat(binary_occupancy_maps, dim=0)
        print(f"Total number of houses: {occupancy_tensor.size(0)}")
        torch.save(occupancy_tensor, file_path)
        print(f"Saved tensor to {file_path}")
        return occupancy_tensor

    def get_random_batch(self, batch_size, *, generator=None):
        """Return ``batch_size`` maps drawn uniformly at random, with replacement.

        Args:
            batch_size: Number of maps to draw.
            generator: Optional ``torch.Generator`` for reproducible sampling.
                ``None`` uses torch's global RNG.

        Returns:
            ``self.maps`` indexed along dim 0 by the sampled positions.
        """
        rand_idxs = torch.randint(
            low=0, high=len(self.maps), size=(batch_size,),
            generator=generator)
        return self.maps[rand_idxs]
@hydra.main(version_base=None, config_path='configs',
            config_name='config')
def main(config):
    """Hydra entry point: build the real-map generator and draw one batch.

    Serves as a quick smoke test. Passing ``stored_maps_file=None`` makes the
    generator rebuild ``32real_occupancy_maps.pth`` from the data loaders on
    every run.

    Args:
        config: Config composed by Hydra from ``configs/config``.
    """
    occ_gen = RealOccupancyGenerator(config, stored_maps_file=None)
    batch = occ_gen.get_random_batch(batch_size=16)
    print(f"Sampled batch with shape {tuple(batch.shape)}")


if __name__ == "__main__":
    # Hydra builds ``config`` from the CLI / config files and passes it in.
    main()