diff --git a/pix2pix.py b/pix2pix.py index d766cb78d..ae36bfe3d 100644 --- a/pix2pix.py +++ b/pix2pix.py @@ -39,6 +39,8 @@ parser.add_argument("--flip", dest="flip", action="store_true", help="flip images horizontally") parser.add_argument("--no_flip", dest="flip", action="store_false", help="don't flip images horizontally") parser.set_defaults(flip=True) +parser.add_argument("--png16bits", dest="png16bits", action="store_true", help="use png 16 bits images encoder and decoders") +parser.set_defaults(png16bits=False) parser.add_argument("--lr", type=float, default=0.0002, help="initial learning rate for adam") parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam") parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient") @@ -258,7 +260,10 @@ def get_name(path): path_queue = tf.train.string_input_producer(input_paths, shuffle=a.mode == "train") reader = tf.WholeFileReader() paths, contents = reader.read(path_queue) - raw_input = decode(contents) + if a.png16bits: + raw_input = decode(contents,dtype=tf.uint16) + else: + raw_input = decode(contents) raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32) assertion = tf.assert_equal(tf.shape(raw_input)[2], 3, message="image does not have 3 channels") @@ -586,7 +591,11 @@ def main(): with tf.variable_scope("generator"): batch_output = deprocess(create_generator(preprocess(batch_input), 3)) - output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0] + if a.png16bits: + output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint16, saturate=True)[0] + else: + output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0] + if a.output_filetype == "png": output_data = tf.image.encode_png(output_image) elif a.output_filetype == "jpeg": @@ -656,7 +665,10 @@ def convert(image): size = [CROP_SIZE, int(round(CROP_SIZE * a.aspect_ratio))] image = tf.image.resize_images(image, size=size, method=tf.image.ResizeMethod.BICUBIC) - return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True) + if a.png16bits: + return tf.image.convert_image_dtype(image, dtype=tf.uint16, saturate=True) + else: + return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True) # reverse any processing on images so they can be written to disk or displayed to user with tf.name_scope("convert_inputs"): @@ -677,14 +689,25 @@ def convert(image): } # summaries - with tf.name_scope("inputs_summary"): - tf.summary.image("inputs", converted_inputs) + if not a.png16bits: + with tf.name_scope("inputs_summary"): + tf.summary.image("inputs", converted_inputs) - with tf.name_scope("targets_summary"): - tf.summary.image("targets", converted_targets) + with tf.name_scope("targets_summary"): + tf.summary.image("targets", converted_targets) - with tf.name_scope("outputs_summary"): - tf.summary.image("outputs", converted_outputs) + with tf.name_scope("outputs_summary"): + tf.summary.image("outputs", converted_outputs) + + else: + with tf.name_scope("inputs_summary"): + tf.summary.image("inputs", tf.image.convert_image_dtype(converted_inputs,dtype=tf.uint8)) + + with tf.name_scope("targets_summary"): + tf.summary.image("targets", tf.image.convert_image_dtype(converted_targets,dtype=tf.uint8)) + + with tf.name_scope("outputs_summary"): + tf.summary.image("outputs", tf.image.convert_image_dtype(converted_outputs,dtype=tf.uint8)) with tf.name_scope("predict_real_summary"): tf.summary.image("predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8)) @@ -692,6 +715,7 @@ def convert(image): with tf.name_scope("predict_fake_summary"): tf.summary.image("predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8)) + tf.summary.scalar("discriminator_loss", model.discrim_loss) tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN) tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)