Weird Side Effects of loadparams! #1979

Open
@FelixBenning

I am posting a picture of the Pluto notebook for better readability (it shows the outputs, etc.). Below you can find copyable code.

Screenshot 2022-05-30 at 16 29 49

Copyable Code
begin
	using Flux: Flux
	using MLDatasets: MNIST
	using LinearAlgebra: norm, LinearAlgebra
end

MNIST.download(i_accept_the_terms_of_use=true)

begin
	x_train, y_train = MNIST.traindata()
	x_train = Float32.(x_train)
	y_train_oh = Flux.onehotbatch(y_train, 0:9)
	size(x_train), size(y_train), size(y_train_oh)
end

function carth_grid(dim=3; start=0, stop=5, length=11)
	Iterators.product(fill(range(start, stop, length=length), dim)...)
end


function toy_model()
	return Flux.Chain(
		Flux.flatten,
		Flux.Dense(foldl(*, size(x_train[:,:,1])), 10, bias=false)
	)
end

begin
	grid = carth_grid(2, start=0, stop=1, length=2)
	collect(grid)
end

begin
	model = toy_model()
	origin = Flux.params(model)
	directions = map(1:length(size(grid))) do _
		dir = [x for x in Flux.params(toy_model())]
		return dir/norm(dir)
	end
	r_grid = map(grid) do coords
		ps = origin .+ sum(zip(coords, directions)) do (coeff, dir)
			coeff * dir
		end
		# Flux.loadparams!(model, deepcopy(ps))
		return ps
	end
	map(pairs(r_grid)) do (x,y)
		norm(x-y)
	end
end

Okay, so what is happening? I define a model generator, toy_model, essentially to produce different parameter initializations. Apparently that is faster than using loadparams (#1764), and it lets me use the default initializer to get a more realistic distribution over initializations. I then sample two parameter vectors from this distribution and normalize them. Since they are random and high-dimensional, they are very likely to be almost orthogonal, so after normalization they are approximately orthonormal. Creating the grid from the coordinates of a Cartesian grid is therefore essentially an orthonormal change of basis.
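The near-orthogonality claim can be checked in isolation. The following is a minimal sketch that does not use Flux: it assumes standard normal entries as a stand-in for the default initializer (which draws from a different distribution), and the length n is taken from the 28 x 28 x 10 weight matrix of the toy model. The names rng, n, d1, d2 and cosine are illustrative only.

```julia
using LinearAlgebra: norm, dot
using Random: MersenneTwister

# Assumption: Gaussian entries stand in for the default Flux initializer.
rng = MersenneTwister(0)
n = 28 * 28 * 10          # number of weights in the 784 -> 10 Dense layer

# Two independent random directions, normalized to unit length.
d1 = randn(rng, Float32, n)
d2 = randn(rng, Float32, n)
d1 ./= norm(d1)
d2 ./= norm(d2)

# Cosine of the angle between them; for independent random vectors this is
# typically on the order of 1/sqrt(n), i.e. close to (but not exactly) zero.
cosine = dot(d1, d2)

# Distance between grid points (0,0) and (1,1) is norm(d1 + d2),
# which equals sqrt(2 + 2 * cosine) and is therefore close to sqrt(2).
diag = norm(d1 + d2)
```

Because the cosine is small but nonzero, the diagonal distance is only approximately sqrt(2).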

Okay, so far so good. When I compute the distances between all pairs of points on the new grid, the result should be the same as before the basis change. The longest distance is between the points (0,0) and (1,1), which is the square root of 2, so ~1.4. And that is in fact the largest value of the output. Nice! Everything works as intended.
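The distance argument can be written out explicitly. The notation here is introduced only for this sketch: o is the origin parameter vector, d_1 and d_2 are the normalized directions, and a, b are grid coordinates. It assumes every grid point is built from the same origin o.

```latex
\begin{aligned}
p(a) &= o + a_1 d_1 + a_2 d_2, \qquad a \in \{0,1\}^2, \\
\|p(a) - p(b)\|^2 &= (a_1 - b_1)^2 \|d_1\|^2 + (a_2 - b_2)^2 \|d_2\|^2
  + 2 (a_1 - b_1)(a_2 - b_2) \langle d_1, d_2 \rangle \\
&\approx (a_1 - b_1)^2 + (a_2 - b_2)^2 = \|a - b\|^2 .
\end{aligned}
```

The approximation uses \|d_1\| = \|d_2\| = 1 and \langle d_1, d_2 \rangle \approx 0. For a = (0,0) and b = (1,1) this gives \|d_1 + d_2\|^2 = 2 + 2 \langle d_1, d_2 \rangle, so the distance is close to \sqrt{2}, and the four axis-aligned pairs have distance 1.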

But if I uncomment the Flux.loadparams! line, the new output makes no sense:

[1.0, 1.41621, 2.83242, 1.0, 2.23859, 1.41621]

This implies that loadparams! somehow modifies ps, even though it only receives a deepcopy?
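One possible explanation, offered as an assumption rather than a confirmed diagnosis: origin holds the model's own weight arrays rather than copies, so loading parameters into the model in place also shifts origin for every later grid point, while ps itself is never touched. The sketch below reproduces that pattern without Flux. The names w, origin, d1, d2, points and dists are illustrative, copyto! stands in for the in-place load, and the directions are exactly orthonormal, so the values are only close to the reported ones.

```julia
using LinearAlgebra: norm

w = zeros(2)                  # stands in for the model's weight array
origin = [w]                  # holds w itself, not a copy (aliasing)
d1, d2 = [1.0, 0.0], [0.0, 1.0]

# Grid points in the same order as collect(grid): (0,0), (1,0), (0,1), (1,1).
points = map([(0, 0), (1, 0), (0, 1), (1, 1)]) do (a, b)
    p = origin[1] .+ a .* d1 .+ b .* d2   # freshly allocated array
    copyto!(w, copy(p))                   # in-place load of a copy into the "model"
    return p
end

# All pairwise distances between the four points.
dists = [norm(points[i] - points[j]) for i in 1:4 for j in i+1:4]
```

Under these assumptions dists comes out as 1, sqrt(2), 2*sqrt(2), 1, sqrt(5), sqrt(2), which has the same pattern as the output above. If this is indeed the cause, taking a deepcopy of the origin parameters once, before the loop, should restore the expected distances.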
