import torch
from torch import nn
from torch.nn import functional as F

from PIL import Image
from torchvision.transforms import functional as TF

# from pytti.image_models import DifferentiableImage
from pytti.image_models.ema import EMAImage, EMAParametersDict
@@ -44,8 +46,8 @@ def load_dip(input_depth, num_scales, offset_type, offset_groups, device):
4446 return dip_net
4547
4648
47- # class DeepImagePrior(EMAImage):
48- class DeepImagePrior (DifferentiableImage ):
49+ class DeepImagePrior (EMAImage ):
50+ # class DeepImagePrior(DifferentiableImage):
4951 """
5052 https://github.com/nousr/deep-image-prior/
5153 """
@@ -69,7 +71,14 @@ def __init__(
6971 device = "cuda" ,
7072 ** kwargs ,
7173 ):
72- super ().__init__ (width * scale , height * scale )
74+ # super(super(EMAImage)).__init__()
75+ nn .Module .__init__ (self )
76+ super ().__init__ (
77+ width = width * scale ,
78+ height = height * scale ,
79+ decay = ema_val ,
80+ device = device ,
81+ )
7382 net = load_dip (
7483 input_depth = input_depth ,
7584 num_scales = num_scales ,
@@ -85,20 +94,38 @@ def __init__(
8594 # z = torch.cat(get_non_offset_params(net), get_offset_params(net))
8695 # logger.debug(z.shape)
8796 # super().__init__(width * scale, height * scale, z, ema_val)
88- self .net = net
97+ # self.net = net
8998 # self.tensor = self.net.params()
9099 self .output_axes = ("n" , "s" , "y" , "x" )
91100 self .scale = scale
92101 self .device = device
93102
94- self ._net_input = torch .randn ([1 , input_depth , width , height ], device = device )
103+ # self._net_input = torch.randn([1, input_depth, width, height], device=device)
95104
96105 self .lr = lr
97106 self .offset_lr_fac = offset_lr_fac
98107 # self._params = [
99108 # {'params': get_non_offset_params(net), 'lr': lr},
100109 # {'params': get_offset_params(net), 'lr': lr * offset_lr_fac}
101110 # ]
111+ # z = {
112+ # 'non_offset':get_non_offset_params(net),
113+ # 'offset':get_offset_params(net),
114+ # }
115+ self .net = net
116+ self ._net_input = torch .randn ([1 , input_depth , width , height ], device = device )
117+
118+ self .image_representation_parameters = EMAParametersDict (
119+ z = self .net , decay = ema_val , device = device
120+ )
121+
122+ # super().__init__(
123+ # width = width * scale,
124+ # height = height * scale,
125+ # tensor = z,
126+ # decay = ema_val,
127+ # device=device,
128+ # )
102129
103130 # def get_image_tensor(self):
104131 def decode_tensor (self ):
@@ -129,17 +156,34 @@ def get_latent_tensor(self, detach=False):
129156 return params
130157
131158 def clone (self ):
132- # dummy = super().__init__ (*self.image_shape)
159+ # dummy = VQGANImage (*self.image_shape)
133160 # with torch.no_grad():
134- # #dummy.tensor.set_(self.tensor.clone())
135- # dummy.net.copy_(self.net)
136- # dummy.accum.set_(self.accum.clone())
137- # dummy.biased.set_(self.biased.clone())
138- # dummy.average.set_(self.average.clone())
139- # dummy.decay = self.decay
140- dummy = deepcopy (self )
161+ # dummy.representation_parameters.set_(self.representation_parameters.clone())
162+ # dummy.accum.set_(self.accum.clone())
163+ # dummy.biased.set_(self.biased.clone())
164+ # dummy.average.set_(self.average.clone())
165+ # dummy.decay = self.decay
166+ # return dummy
167+ dummy = DeepImagePrior (* self .image_shape )
168+ with torch .no_grad ():
169+ # dummy.representation_parameters.set_(self.representation_parameters.clone())
170+ dummy .image_representation_parameters .set_ (
171+ self .image_representation_parameters .clone ()
172+ )
141173 return dummy
142174
175+ # def clone(self):
176+ # # dummy = super().__init__(*self.image_shape)
177+ # # with torch.no_grad():
178+ # # #dummy.tensor.set_(self.tensor.clone())
179+ # # dummy.net.copy_(self.net)
180+ # # dummy.accum.set_(self.accum.clone())
181+ # # dummy.biased.set_(self.biased.clone())
182+ # # dummy.average.set_(self.average.clone())
183+ # # dummy.decay = self.decay
184+ # dummy = deepcopy(self)
185+ # return dummy
186+
    def encode_random(self):
        """No-op stub: randomized re-encoding is not implemented for DIP.

        The deep-image-prior representation is already initialized from a
        random net input in ``__init__``, so this hook intentionally does
        nothing — presumably kept to satisfy the image-model interface.
        TODO(review): confirm callers tolerate the ``None`` return.
        """
        pass
# (GitHub diff-page footer "0 commit comments" — not part of the source file)