-
Notifications
You must be signed in to change notification settings - Fork 0
/
rgb_hsv.py
98 lines (78 loc) · 2.57 KB
/
rgb_hsv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# -*- coding: utf-8 -*-
# Time : 2022/10/4 19:21
# Author : Regulus
# FileName: rgb_hsv.py
# Explain:
# Software: PyCharm
"""
Pytorch implementation of RGB convert to HSV, and HSV convert to RGB,
RGB or HSV's shape: (B * C * H * W)
RGB or HSV's range: [0, 1)
"""
import torch
from torch import nn
class RGB_HSV(nn.Module):
    """Differentiable RGB <-> HSV colour-space conversion for image batches.

    Both directions operate on tensors laid out as (B, 3, H, W); see the
    method docstrings for the expected channel ranges.
    """

    def __init__(self, eps=1e-8):
        """
        Args:
            eps: small constant added to denominators in ``rgb_to_hsv`` so
                grey (max == min) and black (max == 0) pixels never divide
                by zero.
        """
        super().__init__()
        self.eps = eps

    def rgb_to_hsv(self, img):
        """Convert an RGB batch to HSV.

        Args:
            img: tensor of shape (B, 3, H, W) with RGB channels in [0, 1].

        Returns:
            Tensor of shape (B, 3, H, W) holding (hue, saturation, value).
            Hue is in [0, 1); saturation and value are in [0, 1]. Grey pixels
            get hue 0 and black pixels get saturation 0.
        """
        r, g, b = img[:, 0], img[:, 1], img[:, 2]
        # Reduce once instead of re-running max/min for every expression.
        max_val = img.max(1)[0]
        min_val = img.min(1)[0]
        delta = max_val - min_val
        denom = delta + self.eps

        # Allocate zeros (not uninitialised memory) directly on the input's
        # device, keeping the input's floating dtype so float64/float16
        # inputs are not silently converted to float32.
        dtype = img.dtype if img.is_floating_point() else torch.get_default_dtype()
        hue = torch.zeros(max_val.shape, dtype=dtype, device=img.device)

        # Sectors are written blue -> green -> red, so when channels tie for
        # the maximum the later assignment wins (red over green over blue).
        mask = b == max_val
        hue[mask] = 4.0 + ((r - g) / denom)[mask]
        mask = g == max_val
        hue[mask] = 2.0 + ((b - r) / denom)[mask]
        mask = r == max_val
        # (g - b) / delta may be negative here; % 6 wraps it onto [0, 6].
        hue[mask] = ((g - b) / denom)[mask] % 6
        hue[min_val == max_val] = 0.0
        # Tiny negative red-sector values can round up to exactly 6.0, i.e.
        # hue 1.0; the final % 1 folds that back to 0 so hue stays in [0, 1).
        hue = (hue / 6) % 1.0

        saturation = delta / (max_val + self.eps)
        saturation[max_val == 0] = 0
        value = max_val

        return torch.cat(
            [hue.unsqueeze(1), saturation.unsqueeze(1), value.unsqueeze(1)],
            dim=1,
        )

    def hsv_to_rgb(self, hsv):
        """Convert an HSV batch back to RGB.

        Args:
            hsv: tensor of shape (B, 3, H, W) holding (hue, saturation, value).
                Hue is taken modulo 1; saturation and value are clamped to
                [0, 1].

        Returns:
            Tensor of shape (B, 3, H, W) with RGB channels.
        """
        h, s, v = hsv[:, 0], hsv[:, 1], hsv[:, 2]
        # Handle out-of-range inputs: hue wraps around, s and v are clamped.
        h = h % 1
        s = torch.clamp(s, 0, 1)
        v = torch.clamp(v, 0, 1)

        h6 = h * 6
        hi = torch.floor(h6)
        f = h6 - hi
        # h % 1 can round up to exactly 1.0 for tiny negative hues, giving
        # hi == 6, which matches no sector below and would leave the pixel
        # black. Wrap it to sector 0 (f is 0 there, i.e. hue 0).
        hi = hi % 6

        p = v * (1 - s)
        q = v * (1 - (f * s))
        t = v * (1 - ((1 - f) * s))

        r = torch.zeros_like(h)
        g = torch.zeros_like(h)
        b = torch.zeros_like(h)
        # (r, g, b) sources for each of the six 60-degree hue sectors.
        sectors = (
            (v, t, p),
            (q, v, p),
            (p, v, t),
            (p, q, v),
            (t, p, v),
            (v, p, q),
        )
        for idx, (r_src, g_src, b_src) in enumerate(sectors):
            mask = hi == idx
            r[mask] = r_src[mask]
            g[mask] = g_src[mask]
            b[mask] = b_src[mask]

        return torch.cat([r.unsqueeze(1), g.unsqueeze(1), b.unsqueeze(1)], dim=1)