@@ -42,7 +42,7 @@ void puff_advantage_row(float* values, float* rewards, float* dones,
4242
4343void vtrace_check (torch::Tensor values, torch::Tensor rewards,
4444 torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
45- int num_steps, int horizon) {
45+ torch::Tensor gammas, int num_steps, int horizon) {
4646
4747 // Validate input tensors
4848 torch::Device device = values.device ();
@@ -56,36 +56,42 @@ void vtrace_check(torch::Tensor values, torch::Tensor rewards,
5656 t.contiguous ();
5757 }
5858 }
59+ // Validate gammas tensor
60+ TORCH_CHECK (gammas.dim () == 1 , " Gammas must be 1D" );
61+ TORCH_CHECK (gammas.size (0 ) == num_steps, " Gammas size must match num_steps" );
62+ TORCH_CHECK (gammas.dtype () == torch::kFloat32 , " Gammas must be float32" );
63+ TORCH_CHECK (gammas.is_contiguous (), " Gammas must be contiguous" );
5964}
6065
6166
6267// [num_steps, horizon]
6368void puff_advantage (float * values, float * rewards, float * dones, float * importance,
64- float * advantages, float gamma , float lambda, float rho_clip, float c_clip,
69+ float * advantages, float * gammas , float lambda, float rho_clip, float c_clip,
6570 int num_steps, const int horizon){
66- for (int offset = 0 ; offset < num_steps*horizon; offset+=horizon) {
71+ for (int row = 0 ; row < num_steps; row++) {
72+ int offset = row * horizon;
6773 puff_advantage_row (values + offset, rewards + offset,
6874 dones + offset, importance + offset, advantages + offset,
69- gamma , lambda, rho_clip, c_clip, horizon
75+ gammas[row] , lambda, rho_clip, c_clip, horizon
7076 );
7177 }
7278}
7379
7480
7581void compute_puff_advantage_cpu (torch::Tensor values, torch::Tensor rewards,
7682 torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
77- double gamma , double lambda, double rho_clip, double c_clip) {
83+ torch::Tensor gammas , double lambda, double rho_clip, double c_clip) {
7884 int num_steps = values.size (0 );
7985 int horizon = values.size (1 );
80- vtrace_check (values, rewards, dones, importance, advantages, num_steps, horizon);
86+ vtrace_check (values, rewards, dones, importance, advantages, gammas, num_steps, horizon);
8187 puff_advantage (values.data_ptr <float >(), rewards.data_ptr <float >(),
8288 dones.data_ptr <float >(), importance.data_ptr <float >(), advantages.data_ptr <float >(),
83- gamma , lambda, rho_clip, c_clip, num_steps, horizon
89+ gammas. data_ptr < float >() , lambda, rho_clip, c_clip, num_steps, horizon
8490 );
8591}
8692
8793TORCH_LIBRARY (pufferlib, m) {
88- m.def (" compute_puff_advantage(Tensor(a!) values, Tensor(b!) rewards, Tensor(c!) dones, Tensor(d!) importance, Tensor(e!) advantages, float gamma , float lambda, float rho_clip, float c_clip) -> ()" );
94+ m.def (" compute_puff_advantage(Tensor(a!) values, Tensor(b!) rewards, Tensor(c!) dones, Tensor(d!) importance, Tensor(e!) advantages, Tensor gammas , float lambda, float rho_clip, float c_clip) -> ()" );
8995 }
9096
9197TORCH_LIBRARY_IMPL (pufferlib, CPU, m) {
0 commit comments