forked from learning3d/assignment3
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsampler.py
More file actions
46 lines (36 loc) · 1.47 KB
/
sampler.py
File metadata and controls
46 lines (36 loc) · 1.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import math
from typing import List
import torch
from ray_utils import RayBundle
from pytorch3d.renderer.cameras import CamerasBase
import pdb
# Sampler that places points at evenly spaced depths along each ray (uniform, no random jitter)
class StratifiedRaysampler(torch.nn.Module):
    """Sample points along each ray at evenly spaced depths in [min_depth, max_depth].

    Despite the name, no random jitter is applied: every ray receives the same
    ``n_pts_per_ray`` depths produced by ``torch.linspace``.

    Args:
        cfg: config object exposing ``n_pts_per_ray``, ``min_depth`` and
            ``max_depth``.
    """

    def __init__(
        self,
        cfg
    ):
        super().__init__()
        self.n_pts_per_ray = cfg.n_pts_per_ray
        self.min_depth = cfg.min_depth
        self.max_depth = cfg.max_depth

    def forward(
        self,
        ray_bundle,
    ):
        """Attach sample points and their depths to ``ray_bundle``.

        Assumes ``ray_bundle.origins`` and ``ray_bundle.directions`` are
        ``(n_rays, 3)`` tensors (TODO confirm against ray_utils).

        Returns:
            ``ray_bundle._replace(...)`` with
            ``sample_points`` of shape ``(n_rays, n_pts_per_ray, 3)`` and
            ``sample_lengths`` of shape ``(n_rays, n_pts_per_ray, 1)``.
            ``ray_bundle.directions`` is left unmodified.
        """
        origins = ray_bundle.origins
        # Follow the bundle's device/dtype instead of forcing CUDA float32.
        device, dtype = origins.device, origins.dtype
        directions = ray_bundle.directions.to(device=device, dtype=dtype)
        n_rays = origins.shape[0]

        # Same depths for every ray: (n_pts,) -> (n_rays, n_pts, 1) as a broadcast view.
        z_vals = torch.linspace(
            self.min_depth,
            self.max_depth,
            self.n_pts_per_ray,
            device=device,
            dtype=dtype,
        )
        z_vals = z_vals.view(1, -1, 1).expand(n_rays, -1, 1)

        # p = o + z * d, broadcast over the sample axis (no copies of o or d).
        sample_points = origins.unsqueeze(1) + z_vals * directions.unsqueeze(1)

        return ray_bundle._replace(
            sample_points=sample_points,
            # Multiplying by ones materializes a real (non-view) tensor.
            sample_lengths=z_vals * torch.ones_like(sample_points[..., :1]),
        )
# Registry mapping a sampler type name (string key) to its sampler class.
sampler_dict = dict(
    stratified=StratifiedRaysampler,
)