Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions .github/workflows/neuralnet-ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
name: neuralnet-ci

on:
pull_request:
branches: [ main ]
push:
branches: [ main ]

jobs:
cleanup-before:
uses: ./.github/workflows/_cleanup.yml
with:
when: "before"

test-neuralnet:
needs: cleanup-before
name: test-neuralnet - ubuntu-latest
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.11"

- name: Upgrade pip
run: python -m pip install -U pip

- name: Install pufferlib
run: |
pip install -e .[cpu] --no-cache-dir
env:
TMPDIR: ${{ runner.temp }}/build
PIP_NO_CACHE_DIR: 1

- name: Compile C extensions
run: python setup.py build_ext --inplace --force

- name: Run Forward pass
run: python tests/test_drivenet.py
timeout-minutes: 15
45 changes: 45 additions & 0 deletions pufferlib/ocean/drive/binding.c
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
#include "drive.h"
#include "drivenet.h"
#include <Python.h>

#define Env Drive
#define MY_SHARED
#define MY_PUT

static PyObject* test_forward(PyObject* self, PyObject* args, PyObject* kwargs);
#define MY_METHODS {"test_forward", (PyCFunction)test_forward, METH_VARARGS | METH_KEYWORDS, "Test forward pass"}

#include "../env_binding.h"

static int my_put(Env* env, PyObject* args, PyObject* kwargs) {
Expand Down Expand Up @@ -216,3 +223,41 @@ static int my_log(PyObject* dict, Log* log) {
assign_to_dict(dict, "avg_collisions_per_agent", log->avg_collisions_per_agent);
return 0;
}

static PyObject* test_forward(PyObject* self, PyObject* args, PyObject* kwargs) {
PyObject* obs_obj = PyDict_GetItemString(kwargs, "observations");
const char* weights_file = unpack_str(kwargs, "weights_file");
const int dynamics_model = unpack(kwargs, "dynamics_model");

PyArrayObject* obs_array = (PyArrayObject*)obs_obj;
int batch_size = PyArray_DIM(obs_array, 0);
float* observations = (float*)PyArray_DATA(obs_array);

Weights* weights = load_weights(weights_file);
if (!weights) {
PyErr_SetString(PyExc_RuntimeError, "Failed to load weights");
return NULL;
}

DriveNet* net = init_drivenet(weights, batch_size, dynamics_model);

npy_intp action_dims[2] = {batch_size, 2};
npy_intp logit_dims[2] = {batch_size, 20}; // 20 = 7 + 13 (steering + speed logits)

PyObject* actions_array = PyArray_SimpleNew(2, action_dims, NPY_INT32);
PyObject* logits_array = PyArray_SimpleNew(2, logit_dims, NPY_FLOAT32);

int* actions = (int*)PyArray_DATA((PyArrayObject*)actions_array);
float* logits = (float*)PyArray_DATA((PyArrayObject*)logits_array);

forward(net, observations, actions);
memcpy(logits, net->actor->output, batch_size * 20 * sizeof(float));

free_drivenet(net);
free(weights);

PyObject* result = PyTuple_New(2);
PyTuple_SetItem(result, 0, actions_array);
PyTuple_SetItem(result, 1, logits_array);
return result;
}
Binary file modified pufferlib/resources/drive/puffer_drive_weights.bin
Binary file not shown.
Binary file added pufferlib/resources/drive/puffer_drive_weights.pt
Binary file not shown.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def run(self):
for c_ext in c_extensions:
if "drive" in c_ext.name:
c_ext.sources.append("inih-r62/ini.c")
c_ext.include_dirs.append("pufferlib/extensions")
c_ext.extra_compile_args.extend(
[
'-DINI_START_COMMENT_PREFIXES="#"',
Expand Down
70 changes: 70 additions & 0 deletions tests/test_drivenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os
import sys
import numpy as np
import torch

from pufferlib.ocean.torch import Drive, Recurrent
from pufferlib.ocean import env_creator
from pufferlib.ocean.drive import binding


def test_drivenet(
pt_file="resources/drive/puffer_drive_weights.pt",
bin_file="resources/drive/puffer_drive_weights.bin",
batch_size=4,
seed=42,
):
"""Compare logits from PyTorch and C implementations."""

assert os.path.exists(bin_file), f"{bin_file} not found"
assert os.path.exists(pt_file), f"{pt_file} not found"

env = env_creator("puffer_drive")(num_maps=1, num_agents=batch_size, scenario_length=91)
policy = Drive(env, input_size=64, hidden_size=256)
model = Recurrent(env, policy=policy, input_size=256, hidden_size=256)

state_dict = torch.load(pt_file, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

np.random.seed(seed)
torch.manual_seed(seed)
obs = np.random.randn(batch_size, env.num_obs).astype(np.float32)

# Categorical road type features must be integers 0-6
road_start = 7 + 63 * 7
for i in range(200):
obs[:, road_start + i * 7 + 6] = np.random.randint(0, 7, size=batch_size)

with torch.no_grad():
lstm_state = {
"lstm_h": torch.zeros(1, batch_size, 256),
"lstm_c": torch.zeros(1, batch_size, 256),
}
actions_torch, _ = model.forward(torch.from_numpy(obs), lstm_state)

logits_torch = torch.cat(actions_torch, dim=1).cpu().numpy()

# C forward pass
_, logits_c = binding.test_forward(observations=obs, weights_file=bin_file, dynamics_model=0)

diff = np.abs(logits_torch - logits_c)
max_diff = diff.max()
mean_diff = diff.mean()

print(f"First batch:")
print(f"Logits PyTorch: {logits_torch[0, :10]}")
print(f"Logits C: {logits_c[0, :10]}")
print(f" Diff: {diff[0, :10]}")
print(f" Max difference: {max_diff:.6f}")
print(f" Mean difference: {mean_diff:.6f}")

if max_diff < 1e-2:
return True
else:
return False


if __name__ == "__main__":
success = test_drivenet()
sys.exit(0 if success else 1)
Loading