diff --git a/src/qseek/ext/stack.c b/src/qseek/ext/stack.c
new file mode 100644
index 00000000..908c86ff
--- /dev/null
+++ b/src/qseek/ext/stack.c
@@ -0,0 +1,478 @@
+#define PY_SSIZE_T_CLEAN
+#include <Python.h>
+#include <numpy/arrayobject.h>
+#include <immintrin.h>
+#include <omp.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+
+// Structure definitions equivalent to Mojo's structs
+typedef struct {
+  int32_t *shifts;
+  float *weights;
+} Node;
+
+typedef struct {
+  int32_t *shifts;
+  float *weights;
+  float *stack;
+} NodeWithStack;
+
+typedef struct {
+  float *data;
+  Py_ssize_t size;
+  int32_t offset;
+} Trace;
+
+// Function to get thread count
+static inline int get_thread_count(int n_threads) {
+  if (n_threads <= 0) {
+    return omp_get_max_threads();
+  }
+  return n_threads;
+}
+
+static inline Py_ssize_t imax(Py_ssize_t a, Py_ssize_t b) {
+  return a > b ? a : b;
+}
+static inline Py_ssize_t imin(Py_ssize_t a, Py_ssize_t b) {
+  return a < b ? a : b;
+}
+
+// Function to check NumPy array dtype
+static inline int check_array_dtype(PyArrayObject *arr, int expected_type) {
+  if (PyArray_TYPE(arr) != expected_type) {
+    PyErr_Format(PyExc_TypeError, "Input array must be of type %s",
+                 expected_type == NPY_FLOAT32 ? "float32" : "unknown");
+    return 0;
+  }
+  return 1;
+}
+
+// Prepare function equivalent to Mojo's prepare
+static PyObject *prepare(PyObject *traces, PyObject *offsets, PyObject *shifts,
+                         PyObject *weights, Trace **traces_list,
+                         Node **nodes_list, int32_t *min_shift,
+                         int32_t *max_shift) {
+  Py_ssize_t n_traces = PyList_Size(traces);
+  PyArrayObject *shifts_arr =
+      (PyArrayObject *)PyArray_ContiguousFromObject(shifts, NPY_INT32, 2, 2);
+  PyArrayObject *weights_arr =
+      (PyArrayObject *)PyArray_ContiguousFromObject(weights, NPY_FLOAT32, 2, 2);
+  PyArrayObject *offsets_arr =
+      (PyArrayObject *)PyArray_ContiguousFromObject(offsets, NPY_INT32, 1, 1);
+
+  if (!shifts_arr || !weights_arr || !offsets_arr) {
+    Py_XDECREF(shifts_arr);
+    Py_XDECREF(weights_arr);
+    Py_XDECREF(offsets_arr);
+    return NULL;
+  }
+
+  if (n_traces == 0) {
+    PyErr_SetString(PyExc_ValueError,
+                    "Input traces must have positive dimensions");
+    goto cleanup;
+  }
+
+  npy_intp *shifts_shape = PyArray_SHAPE(shifts_arr);
+  Py_ssize_t n_nodes = shifts_shape[0];
+  if (n_nodes == 0) {
+    PyErr_SetString(PyExc_ValueError,
+                    "Input arrays must have positive dimensions");
+    goto cleanup;
+  }
+
+  if (!check_array_dtype(weights_arr, NPY_FLOAT32) ||
+      !check_array_dtype(offsets_arr, NPY_INT32) ||
+      !check_array_dtype(shifts_arr, NPY_INT32)) {
+    goto cleanup;
+  }
+
+  if (shifts_shape[0] != PyArray_SHAPE(weights_arr)[0] ||
+      shifts_shape[1] != PyArray_SHAPE(weights_arr)[1]) {
+    PyErr_SetString(PyExc_ValueError,
+                    "Shifts and weights must have the same shape");
+    goto cleanup;
+  }
+  if (n_traces != PyArray_SHAPE(offsets_arr)[0]) {
+    PyErr_SetString(PyExc_ValueError,
+                    "Number of arrays must match number of offsets");
+    goto cleanup;
+  }
+  if (shifts_shape[1] != n_traces) {
+    PyErr_SetString(PyExc_ValueError,
+                    "Shifts must have the same number of columns as traces");
+    goto cleanup;
+  }
+
+  int32_t *offsets_data = (int32_t *)PyArray_DATA(offsets_arr);
+  int32_t *shifts_data = (int32_t *)PyArray_DATA(shifts_arr);
+  float *weights_data = (float *)PyArray_DATA(weights_arr);
+
+  *traces_list = (Trace *)malloc(n_traces * sizeof(Trace));
+  *nodes_list = (Node *)malloc(n_nodes * sizeof(Node));
+  if (!*traces_list || !*nodes_list) {
+    PyErr_SetString(PyExc_MemoryError, "Failed to allocate memory");
+    goto cleanup_traces; // free(NULL) is safe if only one allocation failed
+  }
+
+  for (Py_ssize_t i = 0; i < n_traces; i++) {
+    PyArrayObject *trace = (PyArrayObject *)PyArray_ContiguousFromObject(
+        PyList_GetItem(traces, i), NPY_FLOAT32, 1, 1);
+    if (!trace)
+      goto cleanup_traces;
+    if (!check_array_dtype(trace, NPY_FLOAT32)) {
+      Py_DECREF(trace);
+      goto cleanup_traces;
+    }
+    (*traces_list)[i].data = (float *)PyArray_DATA(trace);
+    (*traces_list)[i].size = PyArray_SIZE(trace);
+    (*traces_list)[i].offset = offsets_data[i];
+    // We keep only the raw data pointer and release the array object. This
+    // is safe only when the input was already a contiguous float32 array,
+    // so the conversion returned the caller's own array; a converted
+    // temporary would be freed here and leave the pointer dangling.
+    Py_DECREF(trace);
+  }
+
+  for (Py_ssize_t i = 0; i < n_nodes; i++) {
+    (*nodes_list)[i].shifts = shifts_data + i * n_traces;
+    (*nodes_list)[i].weights = weights_data + i * n_traces;
+  }
+
+  *min_shift = INT32_MAX;
+  *max_shift = INT32_MIN;
+  for (Py_ssize_t i = 0; i < n_nodes; i++) {
+    Node node = (*nodes_list)[i];
+    for (Py_ssize_t j = 0; j < n_traces; j++) {
+      int32_t idx_begin = (*traces_list)[j].offset + node.shifts[j];
+      int32_t idx_end = idx_begin + (*traces_list)[j].size;
+      *min_shift = (*min_shift < idx_begin) ? *min_shift : idx_begin;
+      *max_shift = (*max_shift > idx_end) ? *max_shift : idx_end;
+    }
+  }
+
+  // The node pointers above alias shifts_arr/weights_arr; dropping our
+  // references is safe under the same no-copy assumption as the traces.
+  Py_DECREF(shifts_arr);
+  Py_DECREF(weights_arr);
+  Py_DECREF(offsets_arr);
+  return traces;
+
+cleanup_traces:
+  free(*traces_list);
+  free(*nodes_list);
+cleanup:
+  Py_XDECREF(shifts_arr);
+  Py_XDECREF(weights_arr);
+  Py_XDECREF(offsets_arr);
+  return NULL;
+}
+
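+// Worked example of the window bookkeeping above (illustrative numbers):
+// two traces of 100 samples with offsets {0, -10} and one node with shifts
+// {+5, -20} give shifted starts {5, -30} and ends {105, 70}. Hence
+// min_shift = -30, max_shift = 105, the stacked output spans
+// max_shift - min_shift = 135 samples, and output sample 0 corresponds to
+// absolute sample index min_shift.
+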
+static PyObject *stack_traces(PyObject *self, PyObject *args,
+                              PyObject *kwargs) {
+  PyObject *traces, *offsets, *shifts, *weights, *result;
+  result = Py_None; // Default to None if not provided
+  int n_threads = 1;
+  int result_samples = 0;
+
+  static char *kwlist[] = {"traces", "offsets", "shifts", "weights",
+                           "result", "result_samples", "n_threads", NULL};
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOOO|Oii", kwlist, &traces,
+                                   &offsets, &shifts, &weights, &result,
+                                   &result_samples, &n_threads)) {
+    return NULL;
+  }
+
+  Trace *traces_list;
+  Node *nodes_list;
+  int32_t min_shift, max_shift;
+  if (!prepare(traces, offsets, shifts, weights, &traces_list, &nodes_list,
+               &min_shift, &max_shift)) {
+    return NULL;
+  }
+
+  Py_ssize_t n_traces = PyList_Size(traces);
+  // NOTE: assumes `shifts` already is an int32 ndarray, as in prepare();
+  // passing a nested list here would be undefined behaviour.
+  Py_ssize_t n_nodes = PyArray_SHAPE((PyArrayObject *)shifts)[0];
+  Py_ssize_t result_length = max_shift - min_shift;
+  if (result_samples > 0) {
+    result_length = (Py_ssize_t)result_samples;
+    min_shift = 0;
+  }
+
+  PyObject *result_arr;
+  if (result == Py_None) {
+    npy_intp dims[2] = {n_nodes, result_length};
+    // Must be zero-initialised: the kernel below accumulates into it.
+    result_arr = PyArray_ZEROS(2, dims, NPY_FLOAT32, 0);
+    if (!result_arr) {
+      free(traces_list);
+      free(nodes_list);
+      return NULL;
+    }
+  } else {
+    result_arr = PyArray_ContiguousFromObject(result, NPY_FLOAT32, 2, 2);
+    if (!result_arr ||
+        !check_array_dtype((PyArrayObject *)result_arr, NPY_FLOAT32)) {
+      free(traces_list);
+      free(nodes_list);
+      return NULL;
+    }
+    npy_intp *shape = PyArray_SHAPE((PyArrayObject *)result_arr);
+    if (shape[0] != n_nodes || shape[1] != result_length) {
+      PyErr_SetString(PyExc_ValueError,
+                      "Result array must have shape (n_nodes, length_out)");
+      Py_DECREF(result_arr);
+      free(traces_list);
+      free(nodes_list);
+      return NULL;
+    }
+  }
+
+  float *result_data = (float *)PyArray_DATA((PyArrayObject *)result_arr);
+  NodeWithStack *node_list =
+      (NodeWithStack *)malloc(n_nodes * sizeof(NodeWithStack));
+  if (!node_list) {
+    PyErr_SetString(PyExc_MemoryError, "Failed to allocate node_list");
+    Py_DECREF(result_arr);
+    free(traces_list);
+    free(nodes_list);
+    return NULL;
+  }
+
+  for (Py_ssize_t i = 0; i < n_nodes; i++) {
+    node_list[i].shifts = nodes_list[i].shifts;
+    node_list[i].weights = nodes_list[i].weights;
+    node_list[i].stack = result_data + i * result_length;
+  }
+
+  Py_BEGIN_ALLOW_THREADS;
+#pragma omp parallel for num_threads(get_thread_count(n_threads))
+  for (Py_ssize_t i_node = 0; i_node < n_nodes; i_node++) {
+    NodeWithStack node = node_list[i_node];
+    for (Py_ssize_t i_trace = 0; i_trace < n_traces; i_trace++) {
+      float weight = node.weights[i_trace];
+      if (weight == 0.0f)
+        continue;
+
+      Trace trace = traces_list[i_trace];
+      int32_t trace_shift = trace.offset + node.shifts[i_trace];
+      int32_t base_idx = trace_shift - min_shift;
+      Py_ssize_t stack_nsamples = imin(result_length - base_idx, trace.size);
+
+      Py_ssize_t i;
+      __m256 weight_vec = _mm256_set1_ps(weight);
+
+      // Vectorised main loop; the loop bound keeps every 8-wide load and
+      // store inside [start, stack_nsamples), even when the start offset
+      // is not a multiple of 8. The scalar loop handles the remainder.
+      for (i = imax(0, min_shift - trace_shift); i + 8 <= stack_nsamples;
+           i += 8) {
+        Py_ssize_t i_res = base_idx + i;
+        __m256 trace_vec = _mm256_loadu_ps(&trace.data[i]);
+        __m256 stack_vec = _mm256_loadu_ps(&node.stack[i_res]);
+        stack_vec = _mm256_fmadd_ps(trace_vec, weight_vec, stack_vec);
+        _mm256_storeu_ps(&node.stack[i_res], stack_vec);
+      }
+      for (; i < stack_nsamples; i++) {
+        Py_ssize_t i_res = base_idx + i;
+        node.stack[i_res] += trace.data[i] * weight;
+      }
+    }
+  }
+  Py_END_ALLOW_THREADS;
+
+  free(node_list);
+  free(traces_list);
+  free(nodes_list);
+
+  PyObject *ret = PyTuple_New(2);
+  PyTuple_SetItem(ret, 0, result_arr);
+  PyTuple_SetItem(ret, 1, PyLong_FromLong(min_shift));
+  return ret;
+}
+
+// stack_and_reduce function
+static PyObject *stack_and_reduce(PyObject *self, PyObject *args,
+                                  PyObject *kwargs) {
+  PyObject *traces, *offsets, *shifts, *weights;
+  int n_threads = 1;
+
+  static char *kwlist[] = {"traces", "offsets", "shifts",
+                           "weights", "n_threads", NULL};
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOOO|i", kwlist, &traces,
+                                   &offsets, &shifts, &weights, &n_threads)) {
+    return NULL;
+  }
+
+  Trace *traces_list;
+  Node *nodes_list;
+  int32_t min_shift, max_shift;
+  if (!prepare(traces, offsets, shifts, weights, &traces_list, &nodes_list,
+               &min_shift, &max_shift)) {
+    return NULL;
+  }
+
+  Py_ssize_t n_traces = PyList_Size(traces);
+  Py_ssize_t n_nodes = PyArray_SHAPE((PyArrayObject *)shifts)[0];
+  Py_ssize_t result_length = max_shift - min_shift;
+
+  npy_intp dims = result_length;
+  PyObject *node_max = PyArray_SimpleNew(1, &dims, NPY_FLOAT32);
+  PyObject *node_argmax = PyArray_SimpleNew(1, &dims, NPY_UINTP);
+  if (!node_max || !node_argmax) {
+    Py_XDECREF(node_max);
+    Py_XDECREF(node_argmax);
+    free(traces_list);
+    free(nodes_list);
+    return NULL;
+  }
+
+  float *node_max_data = (float *)PyArray_DATA((PyArrayObject *)node_max);
+  // NPY_UINTP matches uint64_t on the 64-bit platforms this targets.
+  uint64_t *node_argmax_data =
+      (uint64_t *)PyArray_DATA((PyArrayObject *)node_argmax);
+  for (Py_ssize_t i = 0; i < result_length; i++) {
+    node_max_data[i] = -NPY_INFINITYF;
+    node_argmax_data[i] = 0;
+  }
+  Py_BEGIN_ALLOW_THREADS;
+#pragma omp parallel num_threads(get_thread_count(n_threads))
+  {
+    // Each thread owns a contiguous, disjoint tile of the output samples.
+    Py_ssize_t tile_start_idx =
+        omp_get_thread_num() * result_length / get_thread_count(n_threads);
+    Py_ssize_t tile_end_idx = (omp_get_thread_num() + 1) * result_length /
+                              get_thread_count(n_threads);
+    Py_ssize_t tile_size = tile_end_idx - tile_start_idx;
+    // aligned_alloc requires the size to be a multiple of the alignment.
+    float *tile_node_stack =
+        (float *)aligned_alloc(32, (tile_size * sizeof(float) + 31) / 32 * 32);
+
+    for (Py_ssize_t i_node = 0; i_node < n_nodes; i_node++) {
+      Node node = nodes_list[i_node];
+      memset(tile_node_stack, 0, tile_size * sizeof(float));
+
+      for (Py_ssize_t i_trace = 0; i_trace < n_traces; i_trace++) {
+        float weight = node.weights[i_trace];
+        if (weight == 0.0f)
+          continue;
+
+        Trace trace = traces_list[i_trace];
+        int32_t trace_shift = trace.offset + node.shifts[i_trace];
+        int32_t base_idx = trace_shift - min_shift;
+        Py_ssize_t tile_base_idx = imax(0, base_idx - tile_start_idx);
+        Py_ssize_t trace_start_idx = imax(0, tile_start_idx - base_idx);
+        Py_ssize_t trace_end_idx = imax(0, tile_end_idx - base_idx);
+        trace_start_idx = imin(trace_start_idx, trace.size);
+        trace_end_idx = imin(trace_end_idx, trace.size);
+        Py_ssize_t n_samples = trace_end_idx - trace_start_idx;
+
+        Py_ssize_t i;
+        __m256 weight_vec = _mm256_set1_ps(weight);
+
+        for (i = 0; i < n_samples - (n_samples % 8); i += 8) {
+          Py_ssize_t i_res = tile_base_idx + i;
+          __m256 trace_vec = _mm256_loadu_ps(&trace.data[trace_start_idx + i]);
+          // tile_base_idx need not be 8-aligned, so use unaligned accesses.
+          __m256 stack_vec = _mm256_loadu_ps(&tile_node_stack[i_res]);
+          stack_vec = _mm256_fmadd_ps(trace_vec, weight_vec, stack_vec);
+          _mm256_storeu_ps(&tile_node_stack[i_res], stack_vec);
+        }
+        for (; i < n_samples; i++) {
+          Py_ssize_t i_res = tile_base_idx + i;
+          tile_node_stack[i_res] += trace.data[trace_start_idx + i] * weight;
+        }
+      }
+
+      for (Py_ssize_t i = 0; i < tile_size; i++) {
+        Py_ssize_t tile_idx = tile_start_idx + i;
+        float node_val = tile_node_stack[i];
+        if (node_val > node_max_data[tile_idx]) {
+          node_max_data[tile_idx] = node_val;
+          node_argmax_data[tile_idx] = i_node;
+        }
+      }
+    }
+    free(tile_node_stack);
+  }
+  Py_END_ALLOW_THREADS;
+
+  free(traces_list);
+  free(nodes_list);
+
+  PyObject *ret = PyTuple_New(3);
+  PyTuple_SetItem(ret, 0, node_max);
+  PyTuple_SetItem(ret, 1, node_argmax);
+  PyTuple_SetItem(ret, 2, PyLong_FromLong(min_shift));
+  return ret;
+}
+
+// stack_snapshot function
+static PyObject *stack_snapshot(PyObject *self, PyObject *args,
+                                PyObject *kwargs) {
+  PyObject *traces, *offsets, *shifts, *weights;
+  int32_t index;
+
+  static char *kwlist[] = {"traces", "offsets", "shifts",
+                           "weights", "index", NULL};
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOOOi", kwlist, &traces,
+                                   &offsets, &shifts, &weights, &index)) {
+    return NULL;
+  }
+
+  Trace *traces_list;
+  Node *nodes_list;
+  int32_t min_shift, max_shift;
+  if (!prepare(traces, offsets, shifts, weights, &traces_list, &nodes_list,
+               &min_shift, &max_shift)) {
+    return NULL;
+  }
+
+  Py_ssize_t n_traces = PyList_Size(traces);
+  Py_ssize_t n_nodes = PyArray_SHAPE((PyArrayObject *)shifts)[0];
+  Py_ssize_t result_length = max_shift - min_shift;
+
+  if (index >= result_length || index < 0) {
+    PyErr_Format(PyExc_ValueError, "Snapshot index out of bounds: %d", index);
+    free(traces_list);
+    free(nodes_list);
+    return NULL;
+  }
+
+  npy_intp dims = n_nodes;
+  PyObject *result = PyArray_SimpleNew(1, &dims, NPY_FLOAT32);
+  if (!result) {
+    free(traces_list);
+    free(nodes_list);
+    return NULL;
+  }
+
+  float *result_data = (float *)PyArray_DATA((PyArrayObject *)result);
+  memset(result_data, 0, n_nodes * sizeof(float));
+
+  for (Py_ssize_t i_node = 0; i_node < n_nodes; i_node++) {
+    Node node = nodes_list[i_node];
+    for (Py_ssize_t i_trace = 0; i_trace < n_traces; i_trace++) {
+      float weight = node.weights[i_trace];
+      Trace trace = traces_list[i_trace];
+      int32_t trace_shift = trace.offset + node.shifts[i_trace];
+      int32_t base_idx = trace_shift - min_shift;
+      int32_t trace_sample = index - base_idx;
+      if (0 <= trace_sample && trace_sample < trace.size) {
+        result_data[i_node] += trace.data[trace_sample] * weight;
+      }
+    }
+  }
+
+  free(traces_list);
+  free(nodes_list);
+  return result;
+}
+
+// Method definitions
+static PyMethodDef StackTracesMethods[] = {
+    {"stack_traces", (PyCFunction)(void (*)(void))stack_traces,
+     METH_VARARGS | METH_KEYWORDS, ""},
+    {"stack_and_reduce", (PyCFunction)(void (*)(void))stack_and_reduce,
+     METH_VARARGS | METH_KEYWORDS, ""},
+    {"stack_snapshot", (PyCFunction)(void (*)(void))stack_snapshot,
+     METH_VARARGS | METH_KEYWORDS, ""},
+    {NULL, NULL, 0, NULL}};
+
+// Module definition
+static PyModuleDef stack = {PyModuleDef_HEAD_INIT, "stack", NULL, -1,
+                            StackTracesMethods};
+
+// Module initialization
+PyMODINIT_FUNC PyInit_stack(void) {
+  import_array(); // Initialize NumPy
+  return PyModule_Create(&stack);
+}
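For review, an unoptimized NumPy sketch of the semantics the stack_traces
kernel above implements, as I read the code (auto-sized output, fresh zeroed
buffer; names illustrative, not part of the patch):

import numpy as np

def stack_traces_reference(traces, offsets, shifts, weights):
    """Reference semantics of stack.stack_traces, without SIMD or threads."""
    n_nodes, n_traces = shifts.shape
    starts = offsets[np.newaxis, :] + shifts  # shifted start of every trace
    min_shift = int(starts.min())
    max_shift = int((starts + [len(t) for t in traces]).max())
    result = np.zeros((n_nodes, max_shift - min_shift), dtype=np.float32)
    for i_node in range(n_nodes):
        for i_trace, trace in enumerate(traces):
            base = starts[i_node, i_trace] - min_shift
            result[i_node, base : base + len(trace)] += (
                weights[i_node, i_trace] * trace
            )
    return result, min_shift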
diff --git a/src/qseek/ext/stack.pyi b/src/qseek/ext/stack.pyi
new file mode 100644
index 00000000..a057dfe5
--- /dev/null
+++ b/src/qseek/ext/stack.pyi
@@ -0,0 +1,25 @@
+import numpy as np
+
+def stack_traces(
+    traces: list[np.ndarray],
+    offsets: np.ndarray,
+    shifts: np.ndarray,
+    weights: np.ndarray,
+    result: np.ndarray | None = None,
+    result_samples: int = 0,
+    n_threads: int = 1,
+) -> tuple[np.ndarray, int]: ...
+def stack_and_reduce(
+    traces: list[np.ndarray],
+    offsets: np.ndarray,
+    shifts: np.ndarray,
+    weights: np.ndarray,
+    n_threads: int = 1,
+) -> tuple[np.ndarray, np.ndarray, int]: ...
+def stack_snapshot(
+    traces: list[np.ndarray],
+    offsets: np.ndarray,
+    shifts: np.ndarray,
+    weights: np.ndarray,
+    index: int,
+) -> np.ndarray: ...
diff --git a/src/qseek/ext_mojo/__init__.mojo b/src/qseek/ext_mojo/__init__.mojo
new file mode 100644
index 00000000..e69de29b
diff --git a/src/qseek/ext_mojo/__init__.py b/src/qseek/ext_mojo/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/qseek/ext_mojo/stack.py b/src/qseek/ext_mojo/stack.py
new file mode 100644
index 00000000..4a77ce9c
--- /dev/null
+++ b/src/qseek/ext_mojo/stack.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+import numpy as np
+
+from qseek.ext_mojo import stack_ext as stack
+
+
+def stack_traces(
+    traces: list[np.ndarray],
+    offsets: np.ndarray,
+    shifts: np.ndarray,
+    weights: np.ndarray,
+    result: np.ndarray | None = None,
+    n_threads: int = 1,
+) -> tuple[np.ndarray, int]:
+    return stack.stack_traces(traces, offsets, shifts, weights, result, n_threads)
+
+
+def stack_and_reduce(
+    traces: list[np.ndarray],
+    offsets: np.ndarray,
+    shifts: np.ndarray,
+    weights: np.ndarray,
+    n_threads: int = 1,
+) -> tuple[np.ndarray, np.ndarray, int]:
+    return stack.stack_and_reduce(traces, offsets, shifts, weights, n_threads)
+
+
+def stack_snapshot(
+    traces: list[np.ndarray],
+    offsets: np.ndarray,
+    shifts: np.ndarray,
+    weights: np.ndarray,
+    index: int,
+) -> np.ndarray:
+    return stack.stack_snapshot(traces, offsets, shifts, weights, index)
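A hedged usage sketch of the bindings above (array shapes and dtypes follow
the validation in prepare; values are illustrative):

import numpy as np
from qseek.ext import stack

n_traces, n_nodes, n_samples = 3, 4, 1000
traces = [np.random.rand(n_samples).astype(np.float32) for _ in range(n_traces)]
offsets = np.zeros(n_traces, dtype=np.int32)
shifts = np.random.randint(-50, 50, (n_nodes, n_traces)).astype(np.int32)
weights = np.ones((n_nodes, n_traces), dtype=np.float32)

# stacked has shape (n_nodes, max_shift - min_shift); offset is min_shift.
stacked, offset = stack.stack_traces(traces, offsets, shifts, weights, n_threads=4)
node_max, node_argmax, offset = stack.stack_and_reduce(
    traces, offsets, shifts, weights, n_threads=4
)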
diff --git a/src/qseek/ext_mojo/stack_ext.mojo b/src/qseek/ext_mojo/stack_ext.mojo
new file mode 100644
index 00000000..e42cc042
--- /dev/null
+++ b/src/qseek/ext_mojo/stack_ext.mojo
@@ -0,0 +1,491 @@
+from python import PythonObject, Python
+from python.bindings import PythonModuleBuilder
+from python.python import CPython
+from sys.info import simd_width_of
+from os import abort
+from algorithm.functional import vectorize, parallelize
+from memory.unsafe_pointer import UnsafePointer
+from memory import memset
+from sys.intrinsics import masked_store, prefetch, PrefetchOptions
+from sys import num_logical_cores
+
+from layout import (
+    LayoutTensor,
+    RuntimeLayout,
+    RuntimeTuple,
+    Layout,
+    IntTuple,
+    UNKNOWN_VALUE,
+)
+
+
+@export
+fn PyInit_stack_ext() -> PythonObject:
+    try:
+        # The builder name must match the extension module name (stack_ext).
+        var m = PythonModuleBuilder("stack_ext")
+        m.def_function[stack_wrapper](
+            "stack_traces",
+            docstring="",
+        )
+        m.def_function[stack_and_reduce_wrapper](
+            "stack_and_reduce",
+            docstring="",
+        )
+        m.def_function[stack_snapshot_wrapper](
+            "stack_snapshot",
+            docstring="",
+        )
+        return m.finalize()
+    except e:
+        return abort[PythonObject](
+            String("error creating Python Mojo module:", e)
+        )
+
+
+@fieldwise_init
+struct Node[dtype: DType](Copyable & Movable):
+    var shifts: UnsafePointer[Int32]
+    var weights: UnsafePointer[Scalar[dtype]]
+
+
+@fieldwise_init
+struct NodeWithStack[dtype: DType](Copyable & Movable):
+    var shifts: UnsafePointer[Int32]
+    var weights: UnsafePointer[Scalar[dtype]]
+    var stack: UnsafePointer[Scalar[dtype]]
+
+    def __init__(
+        out self, node: Node[dtype], stack: UnsafePointer[Scalar[dtype]]
+    ):
+        self.shifts = node.shifts
+        self.weights = node.weights
+        self.stack = stack
+
+
+@fieldwise_init
+struct Trace[dtype: DType](Copyable & Movable):
+    var data: UnsafePointer[Scalar[dtype]]
+    var size: Int
+    var offset: Int
+
+    fn __init__(out self, array: PythonObject, offset: Int) raises:
+        if Int(array.ndim) != 1:
+            raise "Each trace must be a 1D array"
+        try:
+            check_array_dtype[dtype](array)
+        except:
+            raise "Input trace is of wrong dtype"
+
+        # NOTE: keeps a borrowed pointer into `array`; the caller must keep
+        # the trace arrays alive for the lifetime of this struct.
+        self.data = array.ctypes.data.unsafe_get_as_pointer[dtype]()
+        self.size = Int(array.size)
+        self.offset = offset
+
+
+@always_inline
+fn get_thread_count(n_threads: Int) -> Int:
+    if n_threads <= 0:
+        return num_logical_cores()
+    return n_threads
+
+
+@always_inline
+fn get_dtype_char[dtype: DType]() raises -> String:
+    @parameter
+    if dtype is DType.float16:
+        return "e"
+    elif dtype is DType.float32:
+        return "f"
+    elif dtype is DType.float64:
+        return "d"
+    elif dtype is DType.int32:
+        return "i"
+    else:
+        raise "Unsupported dtype"
+
+
+@always_inline
+fn check_array_dtype[dtype: DType](numpy_array: PythonObject) raises:
+    dtype_char = get_dtype_char[dtype]()
+    if String(numpy_array.dtype.char) != dtype_char:
+        raise "Input array must be of type " + String(dtype_char)
+
+
+fn stack_wrapper(
+    traces: PythonObject,
+    offsets: PythonObject,
+    shifts: PythonObject,
+    weights: PythonObject,
+    result: PythonObject,
+    n_threads: PythonObject,
+) raises -> PythonObject:
+    # np = Python.import_module("numpy")
+
+    # n_traces = len(traces)
+    # if n_traces == 0:
+    #     raise "Input traces must have positive dimensions"
+    # trace = traces[0]
+
+    # if trace.dtype == np.float16:
+    #     prepare_func = prepare[DType.float16]
+    # elif trace.dtype == np.float32:
+    #     prepare_func = prepare[DType.float32]
+    # elif trace.dtype == np.float64:
+    #     prepare_func = prepare[DType.float64]
+    # else:
+    #     raise "Unsupported dtype: " + String(trace.dtype)
+
+    traces_list, nodes_list, min_shift, max_shift = prepare[DType.float32](
+        traces=traces,
+        offsets=offsets,
+        shifts=shifts,
+        weights=weights,
+    )
+    return stack_traces(
+        traces_list,
+        nodes_list,
+        min_shift,
+        max_shift,
+        result,
+        get_thread_count(Int(n_threads)),
+    )
+
+
+fn stack_and_reduce_wrapper(
+    traces: PythonObject,
+    offsets: PythonObject,
+    shifts: PythonObject,
+    weights: PythonObject,
+    n_threads: PythonObject,
+) raises -> PythonObject:
+    traces_list, nodes_list, min_shift, max_shift = prepare[DType.float32](
+        traces=traces,
+        offsets=offsets,
+        shifts=shifts,
+        weights=weights,
+    )
+
+    return stack_and_reduce(
+        traces_list,
+        nodes_list,
+        min_shift,
+        max_shift,
+        get_thread_count(Int(n_threads)),
+    )
+
+
+fn stack_snapshot_wrapper(
+    traces: PythonObject,
+    offsets: PythonObject,
+    shifts: PythonObject,
+    weights: PythonObject,
+    index: PythonObject,
+) raises -> PythonObject:
+    traces_list, nodes_list, min_shift, max_shift = prepare[DType.float32](
+        traces=traces,
+        offsets=offsets,
+        shifts=shifts,
+        weights=weights,
+    )
+    return stack_snapshot(
+        traces_list,
+        nodes_list,
+        min_shift,
+        max_shift,
+        Int(index),
+    )
+
+
+fn prepare[
+    dtype: DType
+](
+    traces: PythonObject,
+    offsets: PythonObject,
+    shifts: PythonObject,
+    weights: PythonObject,
+) raises -> Tuple[List[Trace[dtype]], List[Node[dtype]], Int32, Int32]:
+    np = Python.import_module("numpy")
+
+    n_traces = len(traces)
+    n_nodes = Int(shifts.shape[0])
+
+    if n_nodes == 0:
+        raise "Input arrays must have positive dimensions"
+    if n_traces == 0:
+        raise "Input traces must have positive dimensions"
+
+    check_array_dtype[dtype](weights)
+    check_array_dtype[DType.int32](offsets)
+    check_array_dtype[DType.int32](shifts)
+
+    if shifts.shape != weights.shape:
+        raise "Shifts and weights must have the same shape"
+    if n_traces != Int(offsets.shape[0]):
+        raise "Number of arrays must match number of offsets"
+    if Int(shifts.shape[1]) != n_traces:
+        raise "Shifts must have the same number of columns as traces"
+
+    offsets_data = offsets.ctypes.data.unsafe_get_as_pointer[DType.int32]()
+    shifts_data = shifts.ctypes.data.unsafe_get_as_pointer[DType.int32]()
+    weights_data = weights.ctypes.data.unsafe_get_as_pointer[dtype]()
+
+    traces_list = List[Trace[dtype]](capacity=n_traces)
+    node_list = List[Node[dtype]](capacity=n_nodes)
+
+    for i_trace in range(n_traces):
+        traces_list.append(
+            Trace[dtype](
+                array=traces[i_trace],
+                offset=Int(offsets_data[i_trace]),
+            )
+        )
+
+    for i_node in range(n_nodes):
+        node_list.append(
+            Node(
+                shifts=shifts_data + i_node * n_traces,
+                weights=weights_data + i_node * n_traces,
+            )
+        )
+
+    min_shift = Int32.MAX
+    max_shift = Int32.MIN
+    for i_node in range(n_nodes):
+        node = node_list[i_node]
+        for i_trace in range(n_traces):
+            idx_begin = traces_list[i_trace].offset + node.shifts[i_trace]
+            idx_end = idx_begin + traces_list[i_trace].size
+            min_shift = min(min_shift, idx_begin)
+            max_shift = max(max_shift, idx_end)
+
+    return (
+        traces_list,
+        node_list,
+        min_shift,
+        max_shift,
+    )
+
+
+fn stack_traces[
+    dtype: DType
+](
+    traces: List[Trace[dtype]],
+    nodes: List[Node[dtype]],
+    min_shift: Int32,
+    max_shift: Int32,
+    result_arr: PythonObject,
+    n_threads: Int = 16,
+) raises -> PythonObject:
+    cpython = CPython()
+    np = Python.import_module("numpy")
+
+    result_length = max_shift - min_shift
+    n_nodes = len(nodes)
+    n_traces = len(traces)
+
+    result_shape = Python.tuple(n_nodes, result_length)
+    if result_arr is None:
+        result = np.zeros(
+            shape=result_shape,
+            dtype=get_dtype_char[dtype](),
+        )
+    else:
+        if result_arr.shape != result_shape:
+            raise "Result array must have shape (n_nodes, length_out)"
+        check_array_dtype[dtype](result_arr)
+        result = result_arr
+
+    result_data = result.ctypes.data.unsafe_get_as_pointer[dtype]()
+    node_list = List[NodeWithStack[dtype]](capacity=n_nodes)
+    for i_node in range(n_nodes):
+        node_list.append(
+            NodeWithStack[dtype](
+                node=nodes[i_node],
+                stack=result_data + i_node * result_length,
+            )
+        )
+
+    @parameter
+    fn stack_node(i_node: Int):
+        node = node_list[i_node]
+        # prefetch[PrefetchOptions().high_locality()](node.result)
+
+        for i_trace in range(n_traces):
+            weight = node.weights[i_trace]
+            if weight == 0.0:
+                continue
+            trace = traces[i_trace]
+            trace_shift = trace.offset + node.shifts[i_trace]
+            base_idx = trace_shift - min_shift
+
+            @parameter
+            fn stack[width: Int](i_sample: Int):
+                i_res = base_idx + i_sample
+                trace_samples = trace.data.load[width=width](i_sample)
+                stacked_samples = node.stack.load[width=width](i_res)
+
+                stacked_samples += trace_samples * weight
+                node.stack.store(i_res, stacked_samples)
+
+            stack_nsamples = min(result_length - base_idx, trace.size)
+            vectorize[stack, simd_width_of[dtype]()](Int(stack_nsamples))
+
+    state = cpython.PyGILState_Ensure()
+    parallelize[stack_node](n_nodes, n_threads)
+    cpython.PyGILState_Release(state)
+
+    return Python.tuple(result, min_shift)
+
+
+fn stack_and_reduce[
+    dtype: DType
+](
+    traces: List[Trace[dtype]],
+    nodes: List[Node[dtype]],
+    min_shift: Int32,
+    max_shift: Int32,
+    n_threads: Int = 16,
+) raises -> PythonObject:
+    cpython = CPython()
+    np = Python.import_module("numpy")
+    result_length = max_shift - min_shift
+    n_nodes = len(nodes)
+    n_traces = len(traces)
+
+    node_max = np.full(
+        result_length,
+        np.finfo(np.float32).min,
+        dtype=get_dtype_char[dtype](),
+    )
+    node_argmax = np.zeros(result_length, dtype=np.intp)
+    node_max_data = node_max.ctypes.data.unsafe_get_as_pointer[dtype]()
+    node_argmax_data = node_argmax.ctypes.data.unsafe_get_as_pointer[
+        DType.uint64
+    ]()
+
+    @parameter
+    fn stack_tile(i_thread: Int):
+        tile_start_idx = i_thread * result_length // n_threads
+        tile_end_idx = (i_thread + 1) * result_length // n_threads
+        tile_size = Int(tile_end_idx - tile_start_idx)
+
+        tile_node_stack = UnsafePointer[Scalar[dtype]].alloc(tile_size)
+
+        for i_node in range(n_nodes):
+            node = nodes[i_node]
+            # prefetch[PrefetchOptions().high_locality()](node.node_max)
+            memset(tile_node_stack, 0, tile_size)
+
+            for i_trace in range(n_traces):
+                weight = node.weights[i_trace]
+
+                if weight == 0.0:
+                    continue
+
+                trace = traces[i_trace]
+                trace_shift = trace.offset + node.shifts[i_trace]
+
+                base_idx = trace_shift - min_shift
+                tile_base_idx = max(0, base_idx - tile_start_idx)
+
+                trace_start_idx = max(0, tile_start_idx - base_idx)
+                trace_end_idx = max(0, tile_end_idx - base_idx)
+
+                trace_start_idx = min(trace_start_idx, trace.size)
+                trace_end_idx = min(trace_end_idx, trace.size)
+
+                n_samples = trace_end_idx - trace_start_idx
+
+                @parameter
+                fn stack[width: Int](i_sample: Int):
+                    i_res = tile_base_idx + i_sample
+
+                    trace_samples = trace.data.load[width=width](
+                        trace_start_idx + i_sample
+                    )
+                    stacked_samples = tile_node_stack.load[width=width](i_res)
+
+                    stacked_samples += trace_samples * weight
+                    tile_node_stack.store(i_res, stacked_samples)
+
+                vectorize[stack, simd_width_of[dtype]()](Int(n_samples))
+
+            @parameter
+            fn reduce_max[width: Int](idx: Int):
+                tile_idx = tile_start_idx + idx
+                node_vec = tile_node_stack.load[width=width](idx)
+                max_old_vec = node_max_data.load[width=width](tile_idx)
+
+                max_new_vec = max(node_vec, max_old_vec)
+                update_mask = max_old_vec.ne(max_new_vec)
+
+                node_max_data.store(tile_idx, max_new_vec)
+                masked_store(
+                    SIMD[DType.uint64, width](i_node),
+                    node_argmax_data + tile_idx,
+                    mask=update_mask,
+                )
+
+            vectorize[reduce_max, simd_width_of[dtype]()](tile_size)
+
+        tile_node_stack.free()
+
+    state = cpython.PyGILState_Ensure()
+    parallelize[stack_tile](n_threads, n_threads)
+    cpython.PyGILState_Release(state)
+
+    return Python.tuple(node_max, node_argmax, min_shift)
+
+
+fn stack_snapshot[
+    dtype: DType
+](
+    traces: List[Trace[dtype]],
+    nodes: List[Node[dtype]],
+    min_shift: Int32,
+    max_shift: Int32,
+    index: Int32,
+) raises -> PythonObject:
+    cpython = CPython()
+    np = Python.import_module("numpy")
+
+    result_length = max_shift - min_shift
+    if index >= result_length or index < 0:
+        raise "Snapshot index out of bounds: " + String(index)
+
+    n_nodes = len(nodes)
+    n_traces = len(traces)
+
+    result = np.zeros(shape=n_nodes, dtype=get_dtype_char[dtype]())
+    result_data = result.ctypes.data.unsafe_get_as_pointer[dtype]()
+
+    state = cpython.PyGILState_Ensure()
+    for i_node in range(n_nodes):
+        node = nodes[i_node]
+
+        @parameter
+        fn stack_at_sample[width: Int](i_trace: Int):
+            trace_samples = SIMD[dtype, width](0)
+            trace_weights = SIMD[dtype, width](0)
+
+            @parameter
+            for idx_vector in range(width):
+                trace_idx = i_trace + idx_vector
+
+                trace = traces[trace_idx]
+                trace_shift = trace.offset + node.shifts[trace_idx]
+                base_idx = trace_shift - min_shift
+                trace_sample = Int(index - base_idx)
+
+                if 0 <= trace_sample and trace_sample < trace.size:
+                    trace_samples[idx_vector] = trace.data[trace_sample]
+                    trace_weights[idx_vector] = node.weights[trace_idx]
+
+            result_data[i_node] += (trace_samples * trace_weights).reduce_add()
+
+        vectorize[stack_at_sample, simd_width_of[dtype]()](Int(n_traces))
+
+    cpython.PyGILState_Release(state)
+
+    return result
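Likewise for review, a self-contained NumPy sketch of the stack_and_reduce
semantics that both the C and Mojo kernels implement (per-sample maximum and
arg-maximum over nodes; ties keep the earlier node, matching the strict
comparisons above; names illustrative):

import numpy as np

def stack_and_reduce_reference(traces, offsets, shifts, weights):
    """Reference for stack_and_reduce: per-sample max/argmax over nodes."""
    starts = offsets[np.newaxis, :] + shifts
    min_shift = int(starts.min())
    length = int((starts + [len(t) for t in traces]).max()) - min_shift
    best = np.full(length, -np.inf, dtype=np.float32)
    best_node = np.zeros(length, dtype=np.uintp)
    for i_node in range(shifts.shape[0]):
        row = np.zeros(length, dtype=np.float32)
        for i_trace, trace in enumerate(traces):
            base = starts[i_node, i_trace] - min_shift
            row[base : base + len(trace)] += weights[i_node, i_trace] * trace
        better = row > best
        best[better] = row[better]
        best_node[better] = i_node
    return best, best_node, min_shift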
diff --git a/src/qseek/ext_mojo/stack_ext.pyi b/src/qseek/ext_mojo/stack_ext.pyi
new file mode 100644
index 00000000..e01180f2
--- /dev/null
+++ b/src/qseek/ext_mojo/stack_ext.pyi
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+import numpy as np
+
+def stack_traces(
+    traces: list[np.ndarray],
+    offsets: np.ndarray,
+    shifts: np.ndarray,
+    weights: np.ndarray,
+    result: np.ndarray | None = None,
+    n_threads: int = 0,
+) -> tuple[np.ndarray, int]: ...
+def stack_and_reduce(
+    traces: list[np.ndarray],
+    offsets: np.ndarray,
+    shifts: np.ndarray,
+    weights: np.ndarray,
+    n_threads: int = 1,
+) -> tuple[np.ndarray, np.ndarray, int]: ...
+def stack_snapshot(
+    traces: list[np.ndarray],
+    offsets: np.ndarray,
+    shifts: np.ndarray,
+    weights: np.ndarray,
+    index: int,
+) -> np.ndarray: ...
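And the same for stack_snapshot, which returns one stacked value per node at
a single output sample (a sketch of the semantics, not the implementation;
names illustrative):

import numpy as np

def stack_snapshot_reference(traces, offsets, shifts, weights, index):
    """Reference for stack_snapshot: the stack at output sample `index`."""
    n_nodes, n_traces = shifts.shape
    starts = offsets[np.newaxis, :] + shifts
    min_shift = int(starts.min())
    out = np.zeros(n_nodes, dtype=np.float32)
    for i_node in range(n_nodes):
        for i_trace, trace in enumerate(traces):
            sample = index - (starts[i_node, i_trace] - min_shift)
            if 0 <= sample < len(trace):
                out[i_node] += weights[i_node, i_trace] * trace[sample]
    return out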
diff --git a/src/qseek/models/semblance.py b/src/qseek/models/semblance.py
index 5c7a7808..5ae844bd 100644
--- a/src/qseek/models/semblance.py
+++ b/src/qseek/models/semblance.py
@@ -7,12 +7,11 @@
 import numpy as np
 from pydantic import PrivateAttr, computed_field
-from pyrocko import parstack
 from pyrocko.trace import Trace
 from rich.table import Table
 from scipy import signal
 
-from qseek.ext import array_tools
+from qseek.ext import array_tools, stack
 from qseek.ext.array_tools import fill_zero_bytes
 from qseek.stats import Stats
 from qseek.utils import datetime_now, get_cpu_count, human_readable_bytes
@@ -354,15 +353,13 @@ async def add_semblance(
         start_time = datetime_now()
 
         _, offset_samples = await asyncio.to_thread(
-            parstack.parstack,
-            arrays=trace_data,
+            stack.stack_traces,
+            traces=trace_data,
             offsets=offsets,
             shifts=shifts,
             weights=weights,
-            lengthout=self.n_samples_unpadded,
             result=self.semblance_unpadded,
-            dtype=self.semblance_unpadded.dtype,
-            method=0,
-            nparallel=threads,
+            result_samples=self.n_samples_unpadded,
+            n_threads=threads,
         )
         self._stats.add_stacking_time(datetime_now() - start_time, self.n_nodes)
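Reviewer note on the call-site migration above: parstack's lengthout and
nparallel keywords map to result_samples and n_threads, and both kernels
accumulate into a caller-supplied result buffer. A hypothetical equivalence
check, with trace_data, offsets, shifts, weights, n_samples, threads and the
zero-initialized buffers buf_old, buf_new standing in for real inputs:

res_old, _ = parstack.parstack(
    trace_data, offsets, shifts, weights, method=0,
    lengthout=n_samples, result=buf_old, dtype=np.float32, nparallel=threads,
)
res_new, _ = stack.stack_traces(
    trace_data, offsets, shifts, weights,
    result=buf_new, result_samples=n_samples, n_threads=threads,
)
np.testing.assert_allclose(res_new, res_old, rtol=1e-5)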
diff --git a/test/test_stack.py b/test/test_stack.py
new file mode 100644
index 00000000..3cbffe5c
--- /dev/null
+++ b/test/test_stack.py
@@ -0,0 +1,244 @@
+from itertools import product
+
+import numpy as np
+import pytest
+from pyrocko import parstack as pyrocko_parstack
+from pytest import fixture
+
+from qseek.ext import array_tools, stack
+from qseek.ext_mojo import stack as stack_mojo
+
+N_THREADS_TEST = [1, 4, 8]
+ROUNDS = [2]
+
+
+def get_data(n_nodes: int = 12 * 12 * 10, n_samples: int = 30_000, n_traces: int = 100):
+    # n_nodes = 20
+    # n_samples = 500
+    # n_traces = 1
+
+    traces = []
+    for _ in range(n_traces):
+        traces.append(np.random.uniform(0, 32000, (n_samples)).astype(np.float32))
+
+    offsets = np.random.randint(-10, 10, n_traces, dtype=np.int32)
+    shifts = np.random.randint(-100, 100, size=(n_nodes, n_traces)).astype(np.int32)
+    weights = np.ones((n_nodes, n_traces), dtype=np.float32)
+
+    return traces, offsets, shifts, weights
+
+
+@fixture
+def data_repeated():
+    return get_data(n_traces=200)
+
+
+@fixture
+def data():
+    return get_data(n_traces=100)
+
+
+@pytest.mark.parametrize("n_threads", N_THREADS_TEST)
+def test_stack(data, n_threads: int):
+    # TODO: Add mojo test
+    # TODO: Test lengthout / result_samples
+    traces, offsets, shifts, weights = data
+    res, offset = stack.stack_traces(
+        traces,
+        offsets,
+        shifts,
+        weights,
+        result=None,
+        n_threads=n_threads,
+    )
+    res_pyrocko, offset_pyrocko = pyrocko_parstack.parstack(
+        traces,
+        offsets,
+        shifts,
+        weights,
+        dtype=np.float32,
+        method=0,
+        nparallel=n_threads,
+    )
+
+    np.testing.assert_allclose(res, res_pyrocko, rtol=1e-5)
+    assert offset == offset_pyrocko
+
+
+@pytest.mark.parametrize("n_threads", N_THREADS_TEST)
+def test_stack_and_reduce(data, n_threads: int):
+    # TODO: Add mojo test
+    traces, offsets, shifts, weights = data
+    max_res, max_node_idx, offset = stack.stack_and_reduce(
+        traces,
+        offsets,
+        shifts,
+        weights,
+        n_threads=n_threads,
+    )
+
+    res_pyrocko, offset_pyrocko = pyrocko_parstack.parstack(
+        traces,
+        offsets,
+        shifts,
+        weights,
+        result=None,
+        method=0,
+        nparallel=n_threads,
+        dtype=np.float32,
+    )
+    pyrocko_max_node_idx, pyrocko_max_reduce = array_tools.argmax_masked(
+        res_pyrocko,
+        n_threads=n_threads,
+    )
+
+    assert offset == offset_pyrocko
+    np.testing.assert_allclose(max_res, pyrocko_max_reduce, rtol=1e-5)
+    np.testing.assert_equal(max_node_idx, pyrocko_max_node_idx)
+
+
+@pytest.mark.stack
+@pytest.mark.benchmark(group="stack")
+@pytest.mark.parametrize("n_threads, rounds", list(product(N_THREADS_TEST, ROUNDS)))
+def test_stack_simple_qseek(benchmark, data, n_threads: int, rounds: int):
+    traces, offsets, shifts, weights = data
+
+    @benchmark
+    def run() -> None:
+        res = None
+        for _ in range(rounds):
+            res, _ = stack.stack_traces(
+                traces,
+                offsets,
+                shifts,
+                weights,
+                result=res,
+                n_threads=n_threads,
+            )
+        array_tools.argmax_masked(
+            res,
+            n_threads=n_threads,
+        )
+
+    assert run is None
+
+
+@pytest.mark.stack
+@pytest.mark.benchmark(group="stack")
+@pytest.mark.parametrize("n_threads, rounds", list(product(N_THREADS_TEST, ROUNDS)))
+def test_stack_simple_pyrocko(benchmark, data, n_threads: int, rounds: int):
+    traces, offsets, shifts, weights = data
+
+    @benchmark
+    def run() -> None:
+        res = None
+        for _ in range(rounds):
+            res, _ = pyrocko_parstack.parstack(
+                traces,
+                offsets,
+                shifts,
+                weights,
+                method=0,
+                result=res,
+                nparallel=n_threads,
+                dtype=np.float32,
+            )
+        array_tools.argmax_masked(
+            res,
+            n_threads=n_threads,
+        )
+
+    assert run is None
+
+
+@pytest.mark.stack
+@pytest.mark.benchmark(group="stack")
+@pytest.mark.parametrize("n_threads, rounds", list(product(N_THREADS_TEST, ROUNDS)))
+def test_stack_simple_qseek_mojo(benchmark, data, n_threads: int, rounds: int):
+    traces, offsets, shifts, weights = data
+
+    @benchmark
+    def run() -> None:
+        res = None
+        for _ in range(rounds):
+            res, _ = stack_mojo.stack_traces(
+                traces,
+                offsets,
+                shifts,
+                weights,
+                result=res,
+                n_threads=n_threads,
+            )
+        array_tools.argmax_masked(
+            res,
+            n_threads=n_threads,
+        )
+
+    assert run is None
+
+
+@pytest.mark.stack_reduce
+@pytest.mark.benchmark(group="stack_reduce")
+@pytest.mark.parametrize("n_threads", N_THREADS_TEST)
+def test_stack_reduce_qseek(benchmark, data_repeated, n_threads: int):
+    traces, offsets, shifts, weights = data_repeated
+    benchmark(
+        stack.stack_and_reduce,
+        traces,
+        offsets,
+        shifts,
+        weights,
+        n_threads=n_threads,
+    )
+
+
+@pytest.mark.stack_reduce
+@pytest.mark.benchmark(group="stack_reduce")
+@pytest.mark.parametrize("n_threads", N_THREADS_TEST)
+def test_stack_reduce_qseek_mojo(benchmark, data_repeated, n_threads: int):
+    traces, offsets, shifts, weights = data_repeated
+    benchmark(
+        stack_mojo.stack_and_reduce,
+        traces,
+        offsets,
+        shifts,
+        weights,
+        n_threads=n_threads,
+    )
+
+
+@pytest.mark.stack_reduce
+@pytest.mark.benchmark(group="stack_reduce")
+@pytest.mark.parametrize("n_threads", N_THREADS_TEST)
+def test_stack_reduce_pyrocko(benchmark, data, n_threads: int):
+    traces, offsets, shifts, weights = data
+
+    @benchmark
+    def reduce():
+        res, _ = pyrocko_parstack.parstack(
+            traces,
+            offsets,
+            shifts,
+            weights,
+            result=None,
+            method=0,
+            nparallel=n_threads,
+            dtype=np.float32,
+        )
+        res, _ = pyrocko_parstack.parstack(
+            traces,
+            offsets,
+            shifts,
+            weights,
+            result=res,
+            method=0,
+            nparallel=n_threads,
+            dtype=np.float32,
+        )
+
+        array_tools.argmax_masked(
+            res,
+            n_threads=n_threads,
+        )
+
+    assert reduce is None
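The TODOs above note missing coverage; a sketch of a stack_snapshot test that
would fit these fixtures, cross-checking one output column against the full
stack (assumption: the chosen index stays within the output window for the
random shifts used here):

@pytest.mark.parametrize("index", [0, 100])
def test_stack_snapshot(data, index: int):
    traces, offsets, shifts, weights = data
    full, _offset = stack.stack_traces(traces, offsets, shifts, weights)
    snapshot = stack.stack_snapshot(traces, offsets, shifts, weights, index=index)
    np.testing.assert_allclose(snapshot, full[:, index], rtol=1e-5)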