LogicGoInfotechSpaces committed
Commit b6181ba · 1 Parent(s): 22f58a9

Integrate CCO colorization models (eccv16 and siggraph17)

- Add CCO colorizers module from kinsung/cco
- Update /colorize endpoint to support a model selection parameter
- Add scikit-image dependency
- Maintain backward compatibility with the existing GAN model
- Update MongoDB logging to track the model type used

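The updated endpoint can be exercised with a multipart POST. A minimal sketch using requests (the base URL and file name are placeholders, not part of this commit; the field names file, model, and model_used come from the diff below):

import requests

BASE_URL = "http://localhost:7860"  # hypothetical deployment URL

with open("photo.jpg", "rb") as f:  # any grayscale or color image
    resp = requests.post(
        f"{BASE_URL}/colorize",
        files={"file": ("photo.jpg", f, "image/jpeg")},
        data={"model": "cco-siggraph17"},  # "gan" (default), "cco", "cco-eccv16", "cco-siggraph17"
    )
print(resp.json()["model_used"])  # e.g. "cco-siggraph17"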
app/colorizers/__init__.py ADDED
@@ -0,0 +1,6 @@
+
+from .base_color import *
+from .eccv16 import *
+from .siggraph17 import *
+from .util import *
+
app/colorizers/base_color.py ADDED
@@ -0,0 +1,24 @@
+
+import torch
+from torch import nn
+
+class BaseColor(nn.Module):
+    def __init__(self):
+        super(BaseColor, self).__init__()
+
+        self.l_cent = 50.
+        self.l_norm = 100.
+        self.ab_norm = 110.
+
+    def normalize_l(self, in_l):
+        return (in_l-self.l_cent)/self.l_norm
+
+    def unnormalize_l(self, in_l):
+        return in_l*self.l_norm + self.l_cent
+
+    def normalize_ab(self, in_ab):
+        return in_ab/self.ab_norm
+
+    def unnormalize_ab(self, in_ab):
+        return in_ab*self.ab_norm
+
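BaseColor only rescales Lab channels to roughly unit range before inference and back afterwards. A quick round-trip check, as a minimal sketch (values are arbitrary):

import torch
from app.colorizers import BaseColor  # re-exported via the package __init__

bc = BaseColor()
l = torch.tensor([0., 50., 100.])     # L channel spans 0..100 in Lab
ab = torch.tensor([-110., 0., 110.])  # ab channels span roughly -110..110

print(bc.normalize_l(l))              # tensor([-0.5000, 0.0000, 0.5000])
assert torch.allclose(bc.unnormalize_l(bc.normalize_l(l)), l)
assert torch.allclose(bc.unnormalize_ab(bc.normalize_ab(ab)), ab)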
app/colorizers/eccv16.py ADDED
@@ -0,0 +1,105 @@
+
+import torch
+import torch.nn as nn
+import numpy as np
+from IPython import embed
+
+from .base_color import *
+
+class ECCVGenerator(BaseColor):
+    def __init__(self, norm_layer=nn.BatchNorm2d):
+        super(ECCVGenerator, self).__init__()
+
+        model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
+        model1+=[nn.ReLU(True),]
+        model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
+        model1+=[nn.ReLU(True),]
+        model1+=[norm_layer(64),]
+
+        model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+        model2+=[nn.ReLU(True),]
+        model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
+        model2+=[nn.ReLU(True),]
+        model2+=[norm_layer(128),]
+
+        model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+        model3+=[nn.ReLU(True),]
+        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+        model3+=[nn.ReLU(True),]
+        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
+        model3+=[nn.ReLU(True),]
+        model3+=[norm_layer(256),]
+
+        model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+        model4+=[nn.ReLU(True),]
+        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+        model4+=[nn.ReLU(True),]
+        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+        model4+=[nn.ReLU(True),]
+        model4+=[norm_layer(512),]
+
+        model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+        model5+=[nn.ReLU(True),]
+        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+        model5+=[nn.ReLU(True),]
+        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+        model5+=[nn.ReLU(True),]
+        model5+=[norm_layer(512),]
+
+        model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+        model6+=[nn.ReLU(True),]
+        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+        model6+=[nn.ReLU(True),]
+        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+        model6+=[nn.ReLU(True),]
+        model6+=[norm_layer(512),]
+
+        model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+        model7+=[nn.ReLU(True),]
+        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+        model7+=[nn.ReLU(True),]
+        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+        model7+=[nn.ReLU(True),]
+        model7+=[norm_layer(512),]
+
+        model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
+        model8+=[nn.ReLU(True),]
+        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+        model8+=[nn.ReLU(True),]
+        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+        model8+=[nn.ReLU(True),]
+
+        model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]
+
+        self.model1 = nn.Sequential(*model1)
+        self.model2 = nn.Sequential(*model2)
+        self.model3 = nn.Sequential(*model3)
+        self.model4 = nn.Sequential(*model4)
+        self.model5 = nn.Sequential(*model5)
+        self.model6 = nn.Sequential(*model6)
+        self.model7 = nn.Sequential(*model7)
+        self.model8 = nn.Sequential(*model8)
+
+        self.softmax = nn.Softmax(dim=1)
+        self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
+        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')
+
+    def forward(self, input_l):
+        conv1_2 = self.model1(self.normalize_l(input_l))
+        conv2_2 = self.model2(conv1_2)
+        conv3_3 = self.model3(conv2_2)
+        conv4_3 = self.model4(conv3_3)
+        conv5_3 = self.model5(conv4_3)
+        conv6_3 = self.model6(conv5_3)
+        conv7_3 = self.model7(conv6_3)
+        conv8_3 = self.model8(conv7_3)
+        out_reg = self.model_out(self.softmax(conv8_3))
+
+        return self.unnormalize_ab(self.upsample4(out_reg))
+
+def eccv16(pretrained=True):
+    model = ECCVGenerator()
+    if(pretrained):
+        import torch.utils.model_zoo as model_zoo
+        model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',map_location='cpu',check_hash=True))
+    return model
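As a shape sanity check, ECCVGenerator maps a 1-channel L tensor to a 2-channel ab tensor at the input resolution (the three stride-2 stages downsample 8x, model8's transposed conv upsamples 2x, and upsample4 recovers the rest). A minimal sketch; pretrained=False skips the weight download:

import torch
from app.colorizers import eccv16

model = eccv16(pretrained=False).eval()  # no weights fetched for this sketch
with torch.no_grad():
    out_ab = model(torch.zeros(1, 1, 256, 256))  # batch of one 256x256 L channel
print(out_ab.shape)  # torch.Size([1, 2, 256, 256])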
app/colorizers/siggraph17.py ADDED
@@ -0,0 +1,168 @@
+import torch
+import torch.nn as nn
+
+from .base_color import *
+
+class SIGGRAPHGenerator(BaseColor):
+    def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
+        super(SIGGRAPHGenerator, self).__init__()
+
+        # Conv1
+        model1=[nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),]
+        model1+=[nn.ReLU(True),]
+        model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),]
+        model1+=[nn.ReLU(True),]
+        model1+=[norm_layer(64),]
+        # add a subsampling operation
+
+        # Conv2
+        model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+        model2+=[nn.ReLU(True),]
+        model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+        model2+=[nn.ReLU(True),]
+        model2+=[norm_layer(128),]
+        # add a subsampling layer operation
+
+        # Conv3
+        model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+        model3+=[nn.ReLU(True),]
+        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+        model3+=[nn.ReLU(True),]
+        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+        model3+=[nn.ReLU(True),]
+        model3+=[norm_layer(256),]
+        # add a subsampling layer operation
+
+        # Conv4
+        model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+        model4+=[nn.ReLU(True),]
+        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+        model4+=[nn.ReLU(True),]
+        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+        model4+=[nn.ReLU(True),]
+        model4+=[norm_layer(512),]
+
+        # Conv5
+        model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+        model5+=[nn.ReLU(True),]
+        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+        model5+=[nn.ReLU(True),]
+        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+        model5+=[nn.ReLU(True),]
+        model5+=[norm_layer(512),]
+
+        # Conv6
+        model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+        model6+=[nn.ReLU(True),]
+        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+        model6+=[nn.ReLU(True),]
+        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+        model6+=[nn.ReLU(True),]
+        model6+=[norm_layer(512),]
+
+        # Conv7
+        model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+        model7+=[nn.ReLU(True),]
+        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+        model7+=[nn.ReLU(True),]
+        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+        model7+=[nn.ReLU(True),]
+        model7+=[norm_layer(512),]
+
+        # Conv8
+        model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
+        model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+
+        model8=[nn.ReLU(True),]
+        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+        model8+=[nn.ReLU(True),]
+        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+        model8+=[nn.ReLU(True),]
+        model8+=[norm_layer(256),]
+
+        # Conv9
+        model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),]
+        model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+        # add the two feature maps above
+
+        model9=[nn.ReLU(True),]
+        model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+        model9+=[nn.ReLU(True),]
+        model9+=[norm_layer(128),]
+
+        # Conv10
+        model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),]
+        model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+        # add the two feature maps above
+
+        model10=[nn.ReLU(True),]
+        model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),]
+        model10+=[nn.LeakyReLU(negative_slope=.2),]
+
+        # classification output
+        model_class=[nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
+
+        # regression output
+        model_out=[nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
+        model_out+=[nn.Tanh()]
+
+        self.model1 = nn.Sequential(*model1)
+        self.model2 = nn.Sequential(*model2)
+        self.model3 = nn.Sequential(*model3)
+        self.model4 = nn.Sequential(*model4)
+        self.model5 = nn.Sequential(*model5)
+        self.model6 = nn.Sequential(*model6)
+        self.model7 = nn.Sequential(*model7)
+        self.model8up = nn.Sequential(*model8up)
+        self.model8 = nn.Sequential(*model8)
+        self.model9up = nn.Sequential(*model9up)
+        self.model9 = nn.Sequential(*model9)
+        self.model10up = nn.Sequential(*model10up)
+        self.model10 = nn.Sequential(*model10)
+        self.model3short8 = nn.Sequential(*model3short8)
+        self.model2short9 = nn.Sequential(*model2short9)
+        self.model1short10 = nn.Sequential(*model1short10)
+
+        self.model_class = nn.Sequential(*model_class)
+        self.model_out = nn.Sequential(*model_out)
+
+        self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),])
+        self.softmax = nn.Sequential(*[nn.Softmax(dim=1),])
+
+    def forward(self, input_A, input_B=None, mask_B=None):
+        if(input_B is None):
+            input_B = torch.cat((input_A*0, input_A*0), dim=1)
+        if(mask_B is None):
+            mask_B = input_A*0
+
+        conv1_2 = self.model1(torch.cat((self.normalize_l(input_A),self.normalize_ab(input_B),mask_B),dim=1))
+        conv2_2 = self.model2(conv1_2[:,:,::2,::2])
+        conv3_3 = self.model3(conv2_2[:,:,::2,::2])
+        conv4_3 = self.model4(conv3_3[:,:,::2,::2])
+        conv5_3 = self.model5(conv4_3)
+        conv6_3 = self.model6(conv5_3)
+        conv7_3 = self.model7(conv6_3)
+
+        conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
+        conv8_3 = self.model8(conv8_up)
+        conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
+        conv9_3 = self.model9(conv9_up)
+        conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
+        conv10_2 = self.model10(conv10_up)
+        out_reg = self.model_out(conv10_2)
+
+        conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
+        conv9_3 = self.model9(conv9_up)
+        conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
+        conv10_2 = self.model10(conv10_up)
+        out_reg = self.model_out(conv10_2)
+
+        return self.unnormalize_ab(out_reg)
+
+def siggraph17(pretrained=True):
+    model = SIGGRAPHGenerator()
+    if(pretrained):
+        import torch.utils.model_zoo as model_zoo
+        model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True))
+    return model
+
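Unlike eccv16, the SIGGRAPH'17 generator also accepts user color hints: input_B carries ab hint values and mask_B marks hinted pixels (hence the 4-channel first conv: L + ab + mask); when the hints are omitted, forward() substitutes zeros and the model runs fully automatically. A minimal sketch, again with pretrained=False to avoid the download:

import torch
from app.colorizers import siggraph17

model = siggraph17(pretrained=False).eval()
l = torch.zeros(1, 1, 256, 256)      # L channel input
hints = torch.zeros(1, 2, 256, 256)  # ab hints (zeros = no hint colors)
mask = torch.zeros(1, 1, 256, 256)   # 1 where a hint pixel is provided

with torch.no_grad():
    out_auto = model(l)                 # hints default to zeros internally
    out_hinted = model(l, hints, mask)  # same call with explicit hints
print(out_auto.shape)  # torch.Size([1, 2, 256, 256])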
app/colorizers/util.py ADDED
@@ -0,0 +1,47 @@
+
+from PIL import Image
+import numpy as np
+from skimage import color
+import torch
+import torch.nn.functional as F
+from IPython import embed
+
+def load_img(img_path):
+    out_np = np.asarray(Image.open(img_path))
+    if(out_np.ndim==2):
+        out_np = np.tile(out_np[:,:,None],3)
+    return out_np
+
+def resize_img(img, HW=(256,256), resample=3):
+    return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample))
+
+def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
+    # return original size L and resized L as torch Tensors
+    img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)
+
+    img_lab_orig = color.rgb2lab(img_rgb_orig)
+    img_lab_rs = color.rgb2lab(img_rgb_rs)
+
+    img_l_orig = img_lab_orig[:,:,0]
+    img_l_rs = img_lab_rs[:,:,0]
+
+    tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:]
+    tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:]
+
+    return (tens_orig_l, tens_rs_l)
+
+def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
+    # tens_orig_l  1 x 1 x H_orig x W_orig
+    # out_ab       1 x 2 x H x W
+
+    HW_orig = tens_orig_l.shape[2:]
+    HW = out_ab.shape[2:]
+
+    # call resize function if needed
+    if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):
+        out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear')
+    else:
+        out_ab_orig = out_ab
+
+    out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
+    return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))
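Taken together, util.py defines the full CCO pipeline: load an RGB image, extract L tensors at the original and 256x256 sizes, predict ab, then recombine and convert back to RGB. A minimal end-to-end sketch (file names are placeholders; pretrained=True downloads weights on first use):

import torch
from PIL import Image
from app.colorizers import eccv16
from app.colorizers.util import load_img, preprocess_img, postprocess_tens

model = eccv16(pretrained=True).eval()

img = load_img("input.jpg")                   # HxWx3 uint8 RGB array
tens_l_orig, tens_l_rs = preprocess_img(img)  # original-size and 256x256 L tensors
with torch.no_grad():
    out_ab = model(tens_l_rs)                 # 1x2x256x256 ab prediction
rgb = postprocess_tens(tens_l_orig, out_ab)   # float RGB at the original size

Image.fromarray((rgb * 255).astype("uint8")).save("output.png")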
app/config.py CHANGED
@@ -46,7 +46,8 @@ class Settings(BaseSettings):
         "Colorized using GAN-Colorization-Model"
     )
     INFERENCE_PROVIDER: str = os.getenv("INFERENCE_PROVIDER", "fal-ai")
-    INFERENCE_MODEL: str = os.getenv("INFERENCE_MODEL", "black-forest-labs/FLUX.1-Kontext-dev")
+    # Note: black-forest-labs interface not used in main.py - only used in main_sdxl.py
+    INFERENCE_MODEL: str = os.getenv("INFERENCE_MODEL", "")
     INFERENCE_TIMEOUT: int = int(os.getenv("INFERENCE_TIMEOUT", "180"))
    HF_TOKEN: str = os.getenv("HF_TOKEN", "")
 
app/main.py CHANGED
@@ -6,9 +6,11 @@ import uuid
 import os
 import io
 import json
+import logging
 from PIL import Image
 import torch
 from torchvision import transforms
+import numpy as np
 from app.database import (
     get_database,
     log_api_call,
@@ -22,6 +24,17 @@ try:
 except ImportError:
     firebase_auth = None
 
+# Import CCO colorizers
+try:
+    from app.colorizers import eccv16, siggraph17
+    from app.colorizers.util import preprocess_img, postprocess_tens
+    CCO_AVAILABLE = True
+except ImportError as e:
+    print(f"⚠️ CCO colorizers not available: {e}")
+    CCO_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+
 # -------------------------------------------------
 # 🚀 FastAPI App
 # -------------------------------------------------
@@ -63,10 +76,10 @@ MEDIA_CLICK_DEFAULT_CATEGORY = os.getenv("DEFAULT_CATEGORY_FALLBACK", "69368fcd2
 MODEL_REPO = "Hammad712/GAN-Colorization-Model"
 MODEL_FILENAME = "generator.pt"
 
-print("⬇️ Downloading model...")
+print("⬇️ Downloading GAN model...")
 model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
 
-print("📦 Loading model weights...")
+print("📦 Loading GAN model weights...")
 state_dict = torch.load(model_path, map_location="cpu")
 
 # NOTE: Replace with real model architecture
@@ -75,14 +88,76 @@ state_dict = torch.load(model_path, map_location="cpu")
 # model.load_state_dict(state_dict)
 # model.eval()
 
-def colorize_image(img: Image.Image):
-    """ Dummy colorizer (replace with real model.predict) """
+# -------------------------------------------------
+# 🧠 Load CCO Colorization Models
+# -------------------------------------------------
+cco_models = {}
+if CCO_AVAILABLE:
+    print("📦 Loading CCO models...")
+    try:
+        cco_models["eccv16"] = eccv16(pretrained=True).eval()
+        cco_models["siggraph17"] = siggraph17(pretrained=True).eval()
+        print("✅ CCO models loaded successfully!")
+    except Exception as e:
+        print(f"⚠️ Failed to load CCO models: {e}")
+        CCO_AVAILABLE = False
+
+def colorize_image_gan(img: Image.Image):
+    """ GAN colorizer (dummy implementation - replace with real model.predict) """
     transform = transforms.ToTensor()
     tensor = transform(img.convert("L")).unsqueeze(0)
     tensor = tensor.repeat(1, 3, 1, 1)
     output_img = transforms.ToPILImage()(tensor.squeeze())
     return output_img
 
+def colorize_image_cco(img: Image.Image, model_name: str = "eccv16"):
+    """ CCO colorizer using eccv16 or siggraph17 model """
+    if not CCO_AVAILABLE:
+        raise ValueError("CCO models are not available")
+
+    if model_name not in ["eccv16", "siggraph17"]:
+        model_name = "eccv16"  # Default to eccv16
+
+    model = cco_models.get(model_name)
+    if model is None:
+        raise ValueError(f"CCO model '{model_name}' not loaded")
+
+    # Convert PIL Image to numpy array
+    oimg = np.asarray(img)
+    if oimg.ndim == 2:
+        oimg = np.tile(oimg[:,:,None], 3)
+
+    # Preprocess image
+    (tens_l_orig, tens_l_rs) = preprocess_img(oimg)
+
+    # Run model inference
+    with torch.no_grad():
+        out_ab = model(tens_l_rs)
+
+    # Postprocess output
+    output_rgb = postprocess_tens(tens_l_orig, out_ab)
+
+    # Convert numpy array back to PIL Image
+    output_img = Image.fromarray((output_rgb * 255).astype(np.uint8))
+    return output_img
+
+def colorize_image(img: Image.Image, model_type: str = "gan", cco_model: str = "eccv16"):
+    """
+    Colorize image using specified model
+
+    Args:
+        img: PIL Image to colorize
+        model_type: "gan" or "cco"
+        cco_model: "eccv16" or "siggraph17" (only used if model_type is "cco")
+
+    Returns:
+        Colorized PIL Image
+    """
+    if model_type == "cco":
+        return colorize_image_cco(img, cco_model)
+    else:
+        return colorize_image_gan(img)
+
 # -------------------------------------------------
 # 🗄️ MongoDB Initialization
 # -------------------------------------------------
@@ -223,6 +298,7 @@ async def colorize(
     user_id: Optional[str] = Form(None),
     category_id: Optional[str] = Form(None),
     categoryId: Optional[str] = Form(None),
+    model: Optional[str] = Form("gan"),  # New parameter: "gan", "cco", "cco-eccv16", "cco-siggraph17"
 ):
     import time
     start_time = time.time()
@@ -237,6 +313,50 @@
     if not effective_category_id:
         effective_category_id = None
 
+    # Parse model parameter
+    model_type = "gan"  # Default
+    cco_model = "eccv16"  # Default for CCO
+    model_type_for_log = "gan"  # For MongoDB logging
+
+    if model:
+        model = model.strip().lower()
+        if model == "cco" or model.startswith("cco-"):
+            if not CCO_AVAILABLE:
+                error_msg = "CCO models are not available"
+                log_api_call(
+                    endpoint="/colorize",
+                    method="POST",
+                    status_code=400,
+                    error=error_msg,
+                    ip_address=ip_address
+                )
+                log_colorization(
+                    result_id=None,
+                    model_type="cco",
+                    processing_time=None,
+                    user_id=effective_user_id,
+                    ip_address=ip_address,
+                    status="failed",
+                    error=error_msg
+                )
+                raise HTTPException(status_code=400, detail=error_msg)
+
+            model_type = "cco"
+            if model == "cco-eccv16":
+                cco_model = "eccv16"
+                model_type_for_log = "cco-eccv16"
+            elif model == "cco-siggraph17":
+                cco_model = "siggraph17"
+                model_type_for_log = "cco-siggraph17"
+            else:
+                # Default to eccv16 if just "cco" is specified
+                cco_model = "eccv16"
+                model_type_for_log = "cco-eccv16"
+        else:
+            # Default to "gan" for any other value
+            model_type = "gan"
+            model_type_for_log = "gan"
+
 if not file.content_type.startswith("image/"):
         error_msg = "Invalid file type"
         log_api_call(
@@ -249,7 +369,7 @@
         # Log failed colorization
         log_colorization(
             result_id=None,
-            model_type="gan",
+            model_type=model_type_for_log,
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
@@ -260,7 +380,7 @@
 
     try:
         img = Image.open(io.BytesIO(await file.read()))
-        output_img = colorize_image(img)
+        output_img = colorize_image(img, model_type=model_type, cco_model=cco_model)
 
         processing_time = time.time() - start_time
 
@@ -276,13 +396,14 @@
             "success": True,
             "result_id": result_id_clean,
             "download_url": f"{base_url}/results/{result_id}",
-            "api_download": f"{base_url}/download/{result_id_clean}"
+            "api_download": f"{base_url}/download/{result_id_clean}",
+            "model_used": model_type_for_log
         }
 
         # Log to MongoDB (colorization_db -> colorizations)
         log_colorization(
             result_id=result_id_clean,
-            model_type="gan",
+            model_type=model_type_for_log,
             processing_time=processing_time,
             user_id=effective_user_id,
             ip_address=ip_address,
@@ -293,7 +414,7 @@
             endpoint="/colorize",
             method="POST",
             status_code=200,
-            request_data={"filename": file.filename, "content_type": file.content_type},
+            request_data={"filename": file.filename, "content_type": file.content_type, "model": model},
             response_data=response_data,
             user_id=effective_user_id,
             ip_address=ip_address
@@ -314,7 +435,7 @@
         # Log failed colorization to colorizations collection
         log_colorization(
             result_id=None,
-            model_type="gan",
+            model_type=model_type_for_log,
             processing_time=None,
             user_id=effective_user_id,
             ip_address=ip_address,
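Distilling the parameter handling above: the model form value is normalized case-insensitively, "cco" and unrecognized "cco-*" variants fall back to eccv16, and anything else falls back to the GAN path. A standalone sketch of that mapping (illustrative, not code from this commit):

from typing import Optional

def resolve_model(value: Optional[str]) -> str:
    """Mirror the /colorize model-parameter parsing (illustrative sketch)."""
    value = (value or "gan").strip().lower()
    if value == "cco-siggraph17":
        return "cco-siggraph17"
    if value == "cco" or value.startswith("cco-"):
        return "cco-eccv16"  # plain "cco" and unknown "cco-*" values use eccv16
    return "gan"             # anything else uses the GAN model

assert resolve_model("CCO") == "cco-eccv16"
assert resolve_model("cco-siggraph17") == "cco-siggraph17"
assert resolve_model(None) == "gan"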
requirements.txt CHANGED
@@ -17,4 +17,5 @@ safetensors
 ftfy
 httpx
 email-validator
-pymongo
+pymongo
+scikit-image