diff --git a/llama3.cu b/llama3.cu index e1e8406..ca18714 100644 --- a/llama3.cu +++ b/llama3.cu @@ -783,16 +783,6 @@ int sample_argmax(float *probabilities, int n) { return max_i; } -// ---------------------------------------------------------------------------- -// utilities: time -// ---------------------------------------------------------------------------- -long time_in_ms() { - // return time in milliseconds, for benchmarking the model speed - struct timespec time; - clock_gettime(CLOCK_REALTIME, &time); - return time.tv_sec * 1000 + time.tv_nsec / 1000000; -} - // ---------------------------------------------------------------------------- // generation loop // ---------------------------------------------------------------------------- @@ -812,10 +802,14 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, char *prompt, int } // start the main loop - long start = 0; // used to time our code, only initialized after first iteration + cudaEvent_t start, stop; // CUDA events for measuring performance + cudaEventCreate(&start); + cudaEventCreate(&stop); + int next; // will store the next token in the sequence int token = prompt_tokens[0]; // kick off with the first token in the prompt int pos = 0; // position in the sequence + cudaEventRecord(start); while (pos < max_new_tokens - 1) { // forward the transformer to get logits for the next token float *logits = forward(transformer, token, pos); @@ -837,19 +831,20 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, char *prompt, int safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes fflush(stdout); token = next; - - // init the timer here because the first iteration can be slower - if (start == 0) { start = time_in_ms(); } } + cudaEventRecord(stop); printf("\n"); // report achieved tok/s (Token count is assumed to be pos+1 because BOS token must be included) if (pos > 1) { - long end = time_in_ms(); + cudaEventSynchronize(stop); + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); fprintf(stderr, "Token count: %d, elapsed: %fs, %d tokens/s\n", - pos + 1, (float) (end - start) / 1000, (int) ((pos - 1) / (double) (end - start) * 1000)); + pos + 1, milliseconds / 1000, (int) ((pos - 1) / milliseconds * 1000)); } - + cudaEventDestroy(start); + cudaEventDestroy(stop); free(prompt_tokens); }