-
Notifications
You must be signed in to change notification settings - Fork 1
/
image_preprocessing.py
31 lines (26 loc) · 1.01 KB
/
image_preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Image Preprocessing
# Importing the libraries
import numpy as np
#from scipy.misc import imresize
from PIL import Image
from gym.core import ObservationWrapper
from gym.spaces.box import Box
# Preprocessing the Images
class ImageProcessing(ObservationWrapper):
    """Gym observation wrapper that turns image frames into small,
    channel-first float32 arrays.

    Each raw observation is passed through ``crop``, resized to
    ``(height, width)``, optionally averaged over its last axis into a single
    channel, transposed from ``(H, W, C)`` to ``(C, H, W)`` and divided by 255.
    """

    def __init__(self, env, height=64, width=64, grayscale=True,
                 crop=lambda img: img):
        """Wrap ``env`` so its observations are preprocessed.

        Args:
            env: the gym environment whose observations are transformed.
            height: output image height in pixels.
            width: output image width in pixels.
            grayscale: if true, average the last (channel) axis down to one
                channel; otherwise the channels are kept as-is.
            crop: callable applied to the raw observation before resizing;
                the default returns the observation unchanged.
        """
        super(ImageProcessing, self).__init__(env)
        # Stored as (height, width); note that PIL expects (width, height).
        self.img_size = (height, width)
        self.grayscale = grayscale
        self.crop = crop
        # NOTE(review): the colour case advertises 3 channels, which assumes
        # RGB frames from the wrapped env -- confirm for non-RGB envs.
        n_colors = 1 if self.grayscale else 3
        # Channel-first layout, matching what observation() returns.
        self.observation_space = Box(0.0, 1.0, [n_colors, height, width])

    def observation(self, img):
        """Preprocess one raw observation.

        Args:
            img: raw frame from the wrapped env; assumed to be an
                (H, W, C) uint8 array accepted by ``PIL.Image.fromarray``
                -- TODO confirm for the envs in use.

        Returns:
            float32 array of shape (channels, height, width), scaled by 1/255
            (in [0, 1] for uint8 input).
        """
        img = self.crop(img)
        height, width = self.img_size
        # scipy.misc.imresize (removed) took (height, width), but PIL's
        # Image.resize takes (width, height); swap so non-square targets
        # match observation_space instead of coming out transposed.
        img = np.array(Image.fromarray(img).resize((width, height)))
        if self.grayscale:
            # Average over the last axis, keeping it as a size-1 channel axis.
            img = img.mean(-1, keepdims=True)
        # (H, W, C) -> (C, H, W)
        img = np.transpose(img, (2, 0, 1))
        return img.astype('float32') / 255.

    # Older gym releases dispatch to ``_observation`` while newer ones call
    # ``observation``; expose both names so the wrapper works with either.
    _observation = observation