@csestili (Contributor) commented Feb 15, 2024

IMPORTANT: DO NOT MERGE into main branch!

If you want to use this code:

```
git pull
git checkout origin/uncertain-loss
```

To enable the feature, set `model_uncertainty = true` in `config-base.yaml`.
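In YAML syntax, I assume that looks like the following (the rest of `config-base.yaml` is unchanged and omitted here):

```yaml
# config-base.yaml -- only the flag relevant to this branch is shown
model_uncertainty: true
```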

This is an experimental branch where I implemented the loss function from *What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?* (Kendall & Gal, 2017). After implementing it, I trained on a dev dataset and found that the models trained with this method were all worse than a standard CNN. I also struggled to implement the feature in a way that would not make the codebase significantly more complex: due to limitations in the way Keras handles custom loss functions, the result is still fairly complex and would not be easy for anyone to maintain years down the line. Given the limited effectiveness of this method and the added code complexity, I have decided not to merge this feature into the main branch. I strongly encourage future developers not to merge it without taking these considerations into account.

This feature was tested on a binary classification problem (SST open chromatin in mouse). I did not test it on a regression problem, where it might yield more promising results.

Briefly, the idea is:

  • We change the neural network architecture to output both an estimated value (the standard output of the network) and an uncertainty estimate, parameterized as the log variance s = log(sigma^2), that represents the network's uncertainty about this output.
  • We implement this by appending a single extra unit to the end of the existing output units.
  • We introduce a custom loss function such that, when sigma is large, the output of the network is allowed to differ widely from the target value, but when sigma is small, the network is penalized harshly for deviating from the target value.
  • The relevant math is as follows (from models.py):
```python
import tensorflow as tf
from tensorflow import keras


def get_sigma_loss(orig_loss_fn):
    def sigma_loss_fn(y_true, y_pred):
        """Sigma loss: loss function that accounts for predicted variance (sigma).
        Equation 8 from https://arxiv.org/pdf/1703.04977.pdf

        When log_variance is:
            < 0: sigma loss is "more strict" than the original loss
            = 0: sigma loss is equal to the original loss
            > 0: sigma loss is "more permissive" than the original loss

        Inputs:
            y_true, shape [num_examples, 1]
            y_pred, shape [num_examples, num_classes + 1]
                each row of y_pred is assumed to be:
                    (pr_class_0, ..., pr_class_N, log_variance) for classification
                    (pred_target, log_variance) for regression
        """
        # Resolve string identifiers (e.g. "mse") to a loss callable;
        # keras.losses.get returns callables unchanged, and unknown
        # identifiers fall back to orig_loss_fn itself.
        try:
            loss_fn = keras.losses.get(orig_loss_fn)
        except ValueError:
            loss_fn = orig_loss_fn

        # Evaluate the original loss function on the original output of
        # the network (all units except the last).
        orig_loss = loss_fn(y_true, y_pred[:, :-1])

        # Get the uncertainty estimate: the last unit represents
        # s = log(sigma^2) (equation 8), where sigma^2 is the estimated
        # variance associated with this prediction.
        log_variance = y_pred[:, -1]

        # Down-weight the loss where predicted variance is high, and add a
        # log-variance penalty so the network cannot claim to be maximally
        # uncertain about everything.
        sigma_loss = tf.math.exp(-log_variance) * orig_loss + log_variance

        return sigma_loss

    return sigma_loss_fn
```
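In equation form, with $s = \log \sigma^2$ denoting the last output unit and $\mathcal{L}$ the original loss, the code above computes:

$$
\mathcal{L}_\sigma(y, \hat{y}, s) = e^{-s} \, \mathcal{L}(y, \hat{y}) + s
$$

This is equation 8 of the paper up to the constant factors of 1/2 that appear in its Gaussian derivation: the $e^{-s}$ factor relaxes the loss where predicted variance is high, and the $+s$ term penalizes claiming high variance everywhere.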

Most of the code complexity comes from the choice to append sigma to the end of the output units, so that the network still has only one output head. For example, a regression model would have outputs (estimated y value, sigma), while a binary classification model would have (P class 0, P class 1, sigma). This choice is ad hoc and a bit unnatural, and I do not want to set a precedent that we can simply tack on output units to support each new feature: that would get confusing if we ever wanted to combine such a feature with the uncertainty modeling.
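For concreteness, here is a minimal sketch of what that single-head layout could look like in Keras; the architecture and layer sizes are hypothetical, and only `get_sigma_loss` comes from this branch:

```python
from tensorflow import keras

# Hypothetical regression model: the standard prediction plus one appended
# unit for s = log(sigma^2), concatenated into a single output head.
inputs = keras.Input(shape=(16,))
hidden = keras.layers.Dense(32, activation="relu")(inputs)
pred = keras.layers.Dense(1)(hidden)     # estimated y value
log_var = keras.layers.Dense(1)(hidden)  # s = log(sigma^2), unconstrained
outputs = keras.layers.Concatenate()([pred, log_var])  # (pred_target, log_variance)

model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss=get_sigma_loss("mse"))
```

Because the prediction and sigma live in one tensor, a single loss function sees both at once, which is exactly what `sigma_loss_fn` requires.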

However, this was the only implementation I could come up with within Keras. I attempted to add a second output head to the model that would output only sigma, but Keras loss functions operate on each output head individually, so the combined loss above could not be expressed. The sketch below illustrates the dead end.
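A hypothetical sketch of that two-head variant (not code from this branch) shows the problem:

```python
from tensorflow import keras

inputs = keras.Input(shape=(16,))
hidden = keras.layers.Dense(32, activation="relu")(inputs)
pred = keras.layers.Dense(1, name="pred")(hidden)
log_var = keras.layers.Dense(1, name="log_var")(hidden)
model = keras.Model(inputs, [pred, log_var])

# Keras pairs each loss with its own head in isolation: the "pred" loss
# receives only (y_true, pred) and never sees log_var, so the coupled
# term exp(-log_variance) * orig_loss has nowhere to be computed.
model.compile(optimizer="adam", loss={"pred": "mse"})
```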
