|
24 | 24 | @test length(t) == 0
|
25 | 25 | end
|
26 | 26 |
|
27 |
| -@testset "CircularArraySARTSTraces" begin |
| 27 | +@testset "CircularArraySARTSATraces" begin |
28 | 28 | t = CircularArraySARTSATraces(;
|
29 | 29 | capacity=3,
|
30 | 30 | state=Float32 => (2, 3),
|
|
35 | 35 |
|
36 | 36 | @test t isa CircularArraySARTSATraces
|
37 | 37 |
|
38 |
| - push!(t, (state=ones(Float32, 2, 3), action=ones(Float32, 2)) |> gpu) |
| 38 | + push!(t, (state=ones(Float32, 2, 3),)) |
| 39 | + push!(t, (action=ones(Float32, 2), next_state=ones(Float32, 2, 3) * 2) |> gpu) |
39 | 40 | @test length(t) == 0
|
40 | 41 |
|
41 | 42 | push!(t, (reward=1.0f0, terminal=false) |> gpu)
|
42 |
| - @test length(t) == 0 # next_state and next_action is still missing |
| 43 | + @test length(t) == 0 # next_action is still missing |
43 | 44 |
|
44 |
| - push!(t, (next_state=ones(Float32, 2, 3) * 2, next_action=ones(Float32, 2) * 2) |> gpu) |
| 45 | + push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 2) |> gpu) |
45 | 46 | @test length(t) == 1
|
46 | 47 |
|
47 | 48 | # this will trigger the scalar indexing of CuArray
|
|
55 | 56 | )
|
56 | 57 |
|
57 | 58 | push!(t, (reward=2.0f0, terminal=false))
|
58 |
| - push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 3) |> gpu) |
| 59 | + push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 3) |> gpu) |
59 | 60 |
|
60 | 61 | @test length(t) == 2
|
61 | 62 |
|
62 | 63 | push!(t, (reward=3.0f0, terminal=false))
|
63 |
| - push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 4) |> gpu) |
| 64 | + push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 4) |> gpu) |
64 | 65 |
|
65 | 66 | @test length(t) == 3
|
66 | 67 |
|
67 | 68 | push!(t, (reward=4.0f0, terminal=false))
|
68 |
| - push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 5) |> gpu) |
| 69 | + push!(t, (state=ones(Float32, 2, 3) * 6, action=ones(Float32, 2) * 5) |> gpu) |
| 70 | + push!(t, (reward=5.0f0, terminal=false)) |
69 | 71 |
|
70 | 72 | @test length(t) == 3
|
71 | 73 |
|
|
127 | 129 | @test t isa CircularArraySLARTTraces
|
128 | 130 | end
|
129 | 131 |
|
130 |
| -@testset "CircularPrioritizedTraces" begin |
| 132 | +@testset "CircularPrioritizedTraces-SARTS" begin |
131 | 133 | t = CircularPrioritizedTraces(
|
132 |
| - CircularArraySARTSATraces(; |
| 134 | + CircularArraySARTSTraces(; |
133 | 135 | capacity=3
|
134 | 136 | ),
|
135 | 137 | default_priority=1.0f0
|
|
160 | 162 |
|
161 | 163 | #EpisodesBuffer
|
162 | 164 | t = CircularPrioritizedTraces(
|
163 |
| - CircularArraySARTSATraces(; |
| 165 | + CircularArraySARTSTraces(; |
164 | 166 | capacity=10
|
165 | 167 | ),
|
166 | 168 | default_priority=1.0f0
|
|
186 | 188 | eb[:priority, [1, 2]] = [0, 0]
|
187 | 189 | @test eb[:priority] == [zeros(2);ones(8)]
|
188 | 190 | end
|
| 191 | + |
| 192 | +@testset "CircularPrioritizedTraces-SARTSA" begin |
| 193 | + t = CircularPrioritizedTraces( |
| 194 | + CircularArraySARTSATraces(; |
| 195 | + capacity=3 |
| 196 | + ), |
| 197 | + default_priority=1.0f0 |
| 198 | + ) |
| 199 | + |
| 200 | + push!(t, (state=0, action=0)) |
| 201 | + |
| 202 | + for i in 1:5 |
| 203 | + push!(t, (reward=1.0f0, terminal=false, state=i, action=i)) |
| 204 | + end |
| 205 | + |
| 206 | + @test length(t) == 3 |
| 207 | + |
| 208 | + s = BatchSampler(5) |
| 209 | + |
| 210 | + b = sample(s, t) |
| 211 | + |
| 212 | + t[:priority, [1, 2]] = [0, 0] |
| 213 | + |
| 214 | + # shouldn't be changed since [1,2] are old keys |
| 215 | + @test t[:priority] == [1.0f0, 1.0f0, 1.0f0] |
| 216 | + |
| 217 | + t[:priority, [3, 4, 5]] = [0, 1, 0] |
| 218 | + |
| 219 | + b = sample(s, t) |
| 220 | + |
| 221 | + @test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0 |
| 222 | + |
| 223 | + #EpisodesBuffer |
| 224 | + t = CircularPrioritizedTraces( |
| 225 | + CircularArraySARTSATraces(; |
| 226 | + capacity=10 |
| 227 | + ), |
| 228 | + default_priority=1.0f0 |
| 229 | + ) |
| 230 | + |
| 231 | + eb = EpisodesBuffer(t) |
| 232 | + push!(eb, (state = 1,)) |
| 233 | + for i = 1:5 |
| 234 | + push!(eb, (state = i+1, action =i, reward = i, terminal = false)) |
| 235 | + end |
| 236 | + push!(eb, PartialNamedTuple((action = 6,))) |
| 237 | + push!(eb, (state = 7,)) |
| 238 | + for (j,i) = enumerate(8:11) |
| 239 | + push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) |
| 240 | + end |
| 241 | + push!(eb, PartialNamedTuple((action=12,))) |
| 242 | + s = BatchSampler(1000) |
| 243 | + b = sample(s, eb) |
| 244 | + cm = counter(b[:state]) |
| 245 | + @test !haskey(cm, 6) |
| 246 | + @test !haskey(cm, 11) |
| 247 | + @test all(in(keys(cm)), [1:5;7:10]) |
| 248 | +end |
0 commit comments