This repository was archived by the owner on Dec 24, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrandom_patches.py
More file actions
79 lines (56 loc) · 2.2 KB
/
random_patches.py
File metadata and controls
79 lines (56 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from argparse import ArgumentParser
import numpy as np
from collections import namedtuple
import matplotlib.pyplot as plt
import seaborn as sns
from skimage.io import imread
from skimage.transform import resize
from skimage.color import gray2rgb, rgb2gray
from skimage.filters import scharr
from sklearn.feature_extraction.image import extract_patches_2d
from cropping.edge_stats import get_edge_statistics
from cropping.shrink import crop
from cropping.normalize import normalize
import tensorflow as tf
def get_variance(patch: np.ndarray) -> float:
    """Return the *negated* variance of ``patch``.

    Note the sign flip: a flat (low-contrast) patch yields a value near 0 and
    a high-contrast patch yields a large negative value. Presumably intended
    as a sort key so that ascending order ranks patches by descending
    variance — confirm against callers.
    """
    return np.negative(np.var(patch))
def _int64_feature(value) -> tf.train.Feature:
    """Wrap one integer ``value`` in a ``tf.train.Feature`` (``int64_list``)."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value) -> tf.train.Feature:
    """Wrap one bytes ``value`` in a ``tf.train.Feature`` (``bytes_list``)."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _to_rgb(image: np.ndarray) -> np.ndarray:
    """Return ``image`` as a 3-channel array: grey is promoted, alpha dropped."""
    # A single trailing channel is treated as plain greyscale.
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]
    if image.ndim == 2:
        return gray2rgb(image)
    # Already colour: keep only the first three channels so RGBA input works
    # with rgb2gray and with the 3-channel preview canvas.
    return image[..., :3]


def _tile_patches(patches, patch_count: int, patch_size: int, dtype):
    """Lay ``patches`` out row-major on a near-square grid.

    Returns ``(canvas, rows, columns)`` where ``rows`` and ``columns`` are
    Python ints (NumPy rejects float shapes, which ``np.ceil`` would give).
    """
    columns = int(np.ceil(np.sqrt(patch_count)))
    rows = int(np.ceil(patch_count / columns))
    canvas = np.zeros((rows * patch_size, columns * patch_size, 3), dtype)
    for i, patch in enumerate(patches):
        row, column = divmod(i, columns)
        top, left = row * patch_size, column * patch_size
        canvas[top:top + patch_size, left:left + patch_size, :] = patch
    return canvas, rows, columns


def _show_canvas(canvas: np.ndarray, rows: int, columns: int, patch_size: int) -> None:
    """Display the tiled patch canvas in a single seaborn-'white' axes."""
    # axes_style only affects axes created inside the context, so the figure
    # has to be built within it for the style to take effect.
    with sns.axes_style('white'):
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 7))
        ax.imshow(canvas)
        ax.grid(False)
        ax.axis([0, columns * patch_size, rows * patch_size, 0])
        fig.tight_layout()
        sns.despine()
    plt.show()


def main():
    """Preview random patches taken from an energy-cropped image.

    Reads the image path from the command line, computes a Scharr edge
    "energy" map, crops the image using the ``cropping`` package helpers,
    samples ``patch_count`` random ``patch_size``-square patches from the
    cropped image and shows them tiled on one canvas.
    """
    parser = ArgumentParser()
    parser.add_argument('file', metavar='path', help='The file to load.')
    args = parser.parse_args()

    # `as_grey` was renamed to `as_gray` in scikit-image and later removed.
    image = _to_rgb(imread(args.file, as_gray=False))

    # apply energy cropping
    energy = normalize(scharr(rgb2gray(image)))
    stats = get_edge_statistics(energy, edge_width=50)

    # crop the image
    cropped = crop(image, energy, threshold=None, stats=stats)

    # extract the patches; `max_patches` is keyword-only in current sklearn
    patch_count = 16
    patch_size = 128
    patches = extract_patches_2d(cropped, (patch_size, patch_size),
                                 max_patches=patch_count)

    # combine the patches into one image for preview
    canvas, rows, columns = _tile_patches(patches, patch_count, patch_size, image.dtype)
    _show_canvas(canvas, rows, columns, patch_size)
if __name__ == "__main__":
    # Only run the interactive preview when executed as a script, not on import.
    main()