-
Notifications
You must be signed in to change notification settings - Fork 204
/
Copy pathflyingchairs.py
39 lines (33 loc) · 1.21 KB
/
flyingchairs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os.path
import glob
from .listdataset import ListDataset
from .util import split2list
def make_dataset(dir, split=None, split_save_path=None):
    """Collect FlyingChairs sample triplets from ``dir`` and split them.

    Will search for triplets that go by the pattern
    '[name]_img1.ppm [name]_img2.ppm [name]_flow.flo' directly inside ``dir``
    (non-recursive). A flow file is skipped unless both of its matching image
    files exist as regular files.

    Args:
        dir: Directory containing the FlyingChairs files. The name shadows the
            builtin ``dir`` but is kept for backward compatibility with
            keyword callers.
        split: Forwarded unchanged to ``split2list``.
        split_save_path: Forwarded unchanged to ``split2list``.

    Returns:
        The result of ``split2list`` applied to the collected samples with
        ``default_split=0.97``. Each sample is ``[[img1, img2], flow]``, where
        all three entries are file names relative to ``dir``.
    """
    flow_suffix = "_flow.flo"
    # Escape the directory part so glob metacharacters in the path itself
    # (e.g. "[", "*", "?") are matched literally; only the file-name part of
    # the pattern should act as a wildcard.
    pattern = os.path.join(glob.escape(dir), "*" + flow_suffix)
    images = []
    # Sorting keeps the sample order (and hence any split) deterministic
    # regardless of the filesystem's directory listing order.
    for flow_path in sorted(glob.glob(pattern)):
        flow_map = os.path.basename(flow_path)
        root_filename = flow_map[: -len(flow_suffix)]
        img1 = root_filename + "_img1.ppm"
        img2 = root_filename + "_img2.ppm"
        if not (
            os.path.isfile(os.path.join(dir, img1))
            and os.path.isfile(os.path.join(dir, img2))
        ):
            # Incomplete triplet: ignore this flow file.
            continue
        images.append([[img1, img2], flow_map])
    return split2list(images, split, split_save_path, default_split=0.97)
def flying_chairs(
    root,
    transform=None,
    target_transform=None,
    co_transform=None,
    split=None,
    split_save_path=None,
):
    """Build the FlyingChairs training and test datasets located under ``root``.

    The sample lists come from ``make_dataset(root, split, split_save_path)``.
    Both datasets receive ``transform`` and ``target_transform``. Only the
    training dataset also receives ``co_transform``, and the test dataset is
    constructed without it.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``ListDataset`` objects.
    """
    train_samples, test_samples = make_dataset(root, split, split_save_path)
    return (
        ListDataset(root, train_samples, transform, target_transform, co_transform),
        # Held-out split: built without co_transform.
        ListDataset(root, test_samples, transform, target_transform),
    )