-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLinear.m
68 lines (55 loc) · 2.1 KB
/
Linear.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
classdef Linear < Node
    % Linear  Affine combination of parent nodes.
    %   The node's value is
    %       weights(1) + weights(2)*v1 + ... + weights(n+1)*vn
    %   where v1..vn are the values produced by parents(1..n).
    %   weights(1) is the constant (bias) term and weights(1+i) scales
    %   parents(i).
    %
    %   NOTE(review): the protected methods below mutate obj without
    %   returning it, so Node is presumably a handle class - confirm in Node.

    properties
        % Column vector with numel(parents) + 1 coefficients: the bias
        % first, then one weight per parent, in parent order.
        weights(:, 1) double
    end

    methods
        function obj = Linear(parents, weights, varargin)
            % Linear  Construct an affine node.
            %   obj = Linear(parents, weights, ...) builds a node computing
            %   weights(1) + sum_i weights(1+i) * value(parents(i)).
            %
            %   parents  - array of Node objects, or [] for a node with no
            %              parents (a constant equal to weights(1)).
            %   weights  - numeric vector with numel(parents) + 1 entries,
            %              bias first; stored as a column vector.
            %   varargin - forwarded unchanged to the Node constructor.
            %
            %   Errors with id 'Linear:weightCount' when the number of
            %   weights does not equal numel(parents) + 1.

            % Make it easier to specify no parents by using []
            if isempty(parents)
                parents = Node.empty();
            end
            obj = obj@Node(parents, varargin{:});
            obj.weights = weights;
            assert(numel(obj.weights) == numel(obj.parents) + 1, ...
                'Linear:weightCount', ...
                'Linear expects %d weights (bias + one per parent) but got %d.', ...
                numel(obj.parents) + 1, numel(obj.weights));
        end
    end

    methods (Access = protected)
        function value = evalElement(obj)
            % evalElement  Evaluate weights(1) + weights(2:end) . parent values.
            %   evalImpl is invoked once on the whole parent array; its
            %   result is flattened with (:) so either a row or a column of
            %   per-parent values is accepted.
            parentValues = obj.parents.evalImpl();
            % Non-conjugating transpose (.'): a linear combination must not
            % conjugate complex coefficients or values.
            value = obj.weights.' * [1; parentValues(:)];
        end

        function computeScaleElement(obj)
            % computeScaleElement  Propagate parent scales through the affine map.
            %   Starts from [bias bias]. Each parent adds min and max of
            %   weights(1+i) * parents(i).scale to the lower and upper entry
            %   respectively; using min/max handles negative weights, which
            %   flip a range.
            %   NOTE(review): assumes each parent's scale is a [lower upper]
            %   range that has already been computed - confirm in Node.
            bias = obj.weights(1);
            scaleRange = [bias bias];
            for i = 1:numel(obj.parents)
                s = obj.weights(1 + i) * obj.parents(i).scale;
                scaleRange(1) = scaleRange(1) + min(s);
                scaleRange(2) = scaleRange(2) + max(s);
            end
            obj.scale = scaleRange;
        end

        function height = getCompiledHeight(~)
            % getCompiledHeight  Always returns 0 for Linear nodes.
            %   NOTE(review): the meaning of "compiled height" is defined by
            %   Node; presumably an affine node adds no depth - confirm.
            height = 0;
        end

        function simplifyElement(obj)
            % simplifyElement  Flatten Linear parents into this node.
            %   Every parent p that is itself a Linear node is removed and
            %   replaced by p's own parents:
            %     - p's bias, times this node's weight on p, is added to
            %       weights(1);
            %     - each parent q of p either has its scaled weight added to
            %       an existing matching parent (matched with ==), or is
            %       appended as a new parent.
            %   Appended parents are visited later by the same loop, so
            %   chains of nested Linear nodes are flattened completely.
            %   NOTE(review): relies on == comparing node identity (handle
            %   equality) - confirm Node does not overload eq.
            i = 1;
            while i <= numel(obj.parents)
                p = obj.parents(i);
                if ~isa(p, 'Linear')
                    i = i + 1;
                    continue;
                end
                % This node's weight on p; constant inside the inner loop
                % because appending never touches slot 1 + i.
                wp = obj.weights(1 + i);
                obj.weights(1) = obj.weights(1) + wp * p.weights(1);
                for j = 1:numel(p.parents)
                    q = p.parents(j);
                    contribution = wp * p.weights(1 + j);
                    % Single scan: index of q among current parents, or empty.
                    k = find(obj.parents == q, 1);
                    if isempty(k)
                        obj.parents(end + 1) = q;
                        obj.weights(end + 1) = contribution;
                    else
                        obj.weights(1 + k) = obj.weights(1 + k) + contribution;
                    end
                end
                % Drop p itself. Do not advance i: the next parent has
                % shifted into slot i.
                obj.weights(1 + i) = [];
                obj.parents(i) = [];
            end
        end
    end
end