
Commit 270b4cc

Merge branch 'AI-Hypercomputer:main' into logging
2 parents: 977e2a4 + cd44c56

39 files changed: +2459 -411 lines

Makefile

Lines changed: 1 addition & 1 deletion
@@ -51,4 +51,4 @@ unit-tests:
 	coverage run -m unittest -v
 
 check-test-coverage:
-	coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*" --fail-under=96
+	coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*,benchmarks/benchmark_serving.py,benchmarks/eval_accuracy.py" --fail-under=96
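
Note: the extra --omit entries simply exclude the two benchmark scripts from the coverage gate, alongside the generated proto and tokenizer files. As a rough sketch of how glob-style omit patterns behave (approximated here with Python's fnmatch; this is not coverage.py's actual implementation), a file is skipped when any pattern matches its path:

# Illustrative only: approximate the --omit matching with fnmatch.
import fnmatch

OMIT_PATTERNS = [
    "jetstream/core/proto/*",
    "jetstream/engine/tokenizer_pb2.py",
    "jetstream/external_tokenizers/*",
    "benchmarks/benchmark_serving.py",
    "benchmarks/eval_accuracy.py",
]

def is_omitted(path: str) -> bool:
  # A file is dropped from the coverage report if any omit pattern matches.
  return any(fnmatch.fnmatch(path, pattern) for pattern in OMIT_PATTERNS)

print(is_omitted("benchmarks/benchmark_serving.py"))  # True
print(is_omitted("jetstream/core/orchestrator.py"))   # False (example path)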

benchmarks/benchmark_serving.py

Lines changed: 138 additions & 140 deletions
@@ -64,23 +64,21 @@
 from datetime import datetime
 import gc
 import json
+import os
 import random
 import time
 from typing import Any, AsyncGenerator, Optional
-import os
-
 
+from benchmarks.eval_accuracy import eval_accuracy
+from benchmarks.metrics import CounterMetric, EventMetric
 import grpc
-from benchmarks.metrics import EventMetric, CounterMetric
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 from jetstream.engine.token_utils import load_vocab
 from jetstream.external_tokenizers.llama3 import llama3_tokenizer
 import numpy as np
-from tqdm.asyncio import tqdm # pytype: disable=pyi-error
 import pandas
-
-from eval_accuracy import eval_accuracy
+from tqdm.asyncio import tqdm # pytype: disable=pyi-error
 from transformers import AutoTokenizer
 
 
@@ -706,136 +704,7 @@ def sample_warmup_requests(requests):
         break
 
 
-def main(args: argparse.Namespace):
-  print(args)
-  random.seed(args.seed)
-  np.random.seed(args.seed)
-
-  model_id = args.model
-  tokenizer_id = args.tokenizer
-  use_hf_tokenizer = args.use_hf_tokenizer
-
-  prefill_quota = AsyncCounter(init_value=3)
-  active_req_quota = AsyncCounter(init_value=450)
-
-  api_url = f"{args.server}:{args.port}"
-
-  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
-  if tokenizer == "test" or args.dataset == "test":
-    input_requests = mock_requests(
-        args.total_mock_requests
-    ) # e.g. [("AB", 2, "AB", 3)]
-  else:
-    dataset = []
-    if args.dataset == "openorca":
-      dataset = load_openorca_dataset_pkl(args.dataset_path)
-    elif args.dataset == "sharegpt":
-      dataset = load_sharegpt_dataset(
-          args.dataset_path,
-          args.conversation_starter,
-      )
-
-    # A given args.max_output_length value is the max generation step,
-    # when the args.max_output_length is default to None, the sample's golden
-    # output length will be used to decide the generation step.
-    input_requests = sample_requests(
-        dataset=dataset,
-        tokenizer=tokenizer,
-        num_requests=args.num_prompts,
-        max_output_length=args.max_output_length,
-    )
-
-  warmup_requests = None
-  if args.warmup_mode == "full":
-    warmup_requests = input_requests
-  elif args.warmup_mode == "sampled":
-    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
-
-  if warmup_requests:
-    print(f"Warmup (mode: {args.warmup_mode}) is starting.")
-    _, _ = asyncio.run(
-        benchmark(
-            api_url=api_url,
-            tokenizer=tokenizer,
-            input_requests=warmup_requests,
-            request_rate=args.request_rate,
-            disable_tqdm=args.disable_tqdm,
-            prefill_quota=prefill_quota,
-            active_req_quota=active_req_quota,
-            is_warmup=True,
-        )
-    )
-    print(f"Warmup (mode: {args.warmup_mode}) has completed.")
-
-    # TODO: Replace this with warmup complete signal once supported.
-    # Wait for server completely warmup before running the benchmark.
-    time.sleep(5)
-
-  benchmark_result, request_outputs = asyncio.run(
-      benchmark(
-          api_url=api_url,
-          tokenizer=tokenizer,
-          input_requests=input_requests,
-          request_rate=args.request_rate,
-          disable_tqdm=args.disable_tqdm,
-          prefill_quota=prefill_quota,
-          active_req_quota=active_req_quota,
-      )
-  )
-
-  # Process output
-  output = [output.to_dict() for output in request_outputs]
-  if args.run_eval:
-    eval_json = eval_accuracy(output)
-
-  # Save config and results to json
-  if args.save_result:
-    # dimensions values are strings
-    dimensions_json = {}
-    # metrics values are numerical
-    metrics_json = {}
-
-    # Setup
-    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
-    dimensions_json["date"] = current_dt
-    dimensions_json["model_id"] = model_id
-    dimensions_json["tokenizer_id"] = tokenizer_id
-    if args.additional_metadata_metrics_to_save is not None:
-      dimensions_json = {
-          **dimensions_json,
-          **json.loads(args.additional_metadata_metrics_to_save),
-      }
-    metrics_json["num_prompts"] = args.num_prompts
-
-    # Traffic
-    metrics_json["request_rate"] = args.request_rate
-    metrics_json = {**metrics_json, **benchmark_result}
-    if args.run_eval:
-      metrics_json = {**metrics_json, **eval_json}
-
-    final_json = {}
-    final_json["metrics"] = metrics_json
-    final_json["dimensions"] = dimensions_json
-
-    # Save to file
-    base_model_id = model_id.split("/")[-1]
-    file_name = (
-        f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
-    )
-    with open(file_name, "w", encoding="utf-8") as outfile:
-      json.dump(final_json, outfile)
-
-  if args.save_request_outputs:
-    file_path = args.request_outputs_file_path
-    with open(file_path, "w", encoding="utf-8") as output_file:
-      json.dump(
-          output,
-          output_file,
-          indent=4,
-      )
-
-
-if __name__ == "__main__":
+def parse_args() -> argparse.Namespace:
   parser = argparse.ArgumentParser(
       description="Benchmark the online serving throughput."
   )
@@ -909,7 +778,6 @@ def main(args: argparse.Namespace):
       default=150,
       help="The maximum number of mock requests to send for benchmark testing.",
   )
-
   parser.add_argument(
       "--max-output-length",
       type=int,
@@ -926,7 +794,6 @@ def main(args: argparse.Namespace):
           "the output length of the golden dataset would be passed."
       ),
   )
-
   parser.add_argument("--seed", type=int, default=0)
   parser.add_argument(
       "--disable-tqdm",
@@ -977,7 +844,138 @@ def main(args: argparse.Namespace):
       choices=["human", "gpt", "both"],
       help="What entity should be the one starting the conversations.",
   )
+  return parser.parse_args()
+
+
+def main(args: argparse.Namespace):
+  print(args)
+  random.seed(args.seed)
+  np.random.seed(args.seed)
+
+  model_id = args.model
+  tokenizer_id = args.tokenizer
+  use_hf_tokenizer = args.use_hf_tokenizer
+
+  prefill_quota = AsyncCounter(init_value=3)
+  active_req_quota = AsyncCounter(init_value=450)
+
+  api_url = f"{args.server}:{args.port}"
+
+  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
+  if tokenizer == "test" or args.dataset == "test":
+    input_requests = mock_requests(
+        args.total_mock_requests
+    ) # e.g. [("AB", 2, "AB", 3)]
+  else:
+    dataset = []
+    if args.dataset == "openorca":
+      dataset = load_openorca_dataset_pkl(args.dataset_path)
+    elif args.dataset == "sharegpt":
+      dataset = load_sharegpt_dataset(
+          args.dataset_path,
+          args.conversation_starter,
+      )
+
+    # A given args.max_output_length value is the max generation step,
+    # when the args.max_output_length is default to None, the sample's golden
+    # output length will be used to decide the generation step.
+    input_requests = sample_requests(
+        dataset=dataset,
+        tokenizer=tokenizer,
+        num_requests=args.num_prompts,
+        max_output_length=args.max_output_length,
+    )
+
+  warmup_requests = None
+  if args.warmup_mode == "full":
+    warmup_requests = input_requests
+  elif args.warmup_mode == "sampled":
+    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
+
+  if warmup_requests:
+    print(f"Warmup (mode: {args.warmup_mode}) is starting.")
+    _, _ = asyncio.run(
+        benchmark(
+            api_url=api_url,
+            tokenizer=tokenizer,
+            input_requests=warmup_requests,
+            request_rate=args.request_rate,
+            disable_tqdm=args.disable_tqdm,
+            prefill_quota=prefill_quota,
+            active_req_quota=active_req_quota,
+            is_warmup=True,
+        )
+    )
+    print(f"Warmup (mode: {args.warmup_mode}) has completed.")
+
+    # TODO: Replace this with warmup complete signal once supported.
+    # Wait for server completely warmup before running the benchmark.
+    time.sleep(5)
+
+  benchmark_result, request_outputs = asyncio.run(
+      benchmark(
+          api_url=api_url,
+          tokenizer=tokenizer,
+          input_requests=input_requests,
+          request_rate=args.request_rate,
+          disable_tqdm=args.disable_tqdm,
+          prefill_quota=prefill_quota,
+          active_req_quota=active_req_quota,
+      )
+  )
+
+  # Process output
+  output = [output.to_dict() for output in request_outputs]
+  if args.run_eval:
+    eval_json = eval_accuracy(output)
+
+  # Save config and results to json
+  if args.save_result:
+    # dimensions values are strings
+    dimensions_json = {}
+    # metrics values are numerical
+    metrics_json = {}
 
-  parsed_args = parser.parse_args()
+    # Setup
+    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
+    dimensions_json["date"] = current_dt
+    dimensions_json["model_id"] = model_id
+    dimensions_json["tokenizer_id"] = tokenizer_id
+    if args.additional_metadata_metrics_to_save is not None:
+      dimensions_json = {
+          **dimensions_json,
+          **json.loads(args.additional_metadata_metrics_to_save),
+      }
+    metrics_json["num_prompts"] = args.num_prompts
+
+    # Traffic
+    metrics_json["request_rate"] = args.request_rate
+    metrics_json = {**metrics_json, **benchmark_result}
+    if args.run_eval:
+      metrics_json = {**metrics_json, **eval_json}
+
+    final_json = {}
+    final_json["metrics"] = metrics_json
+    final_json["dimensions"] = dimensions_json
+
+    # Save to file
+    base_model_id = model_id.split("/")[-1]
+    file_name = (
+        f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
+    )
+    with open(file_name, "w", encoding="utf-8") as outfile:
+      json.dump(final_json, outfile)
+
+  if args.save_request_outputs:
+    file_path = args.request_outputs_file_path
+    with open(file_path, "w", encoding="utf-8") as output_file:
+      json.dump(
+          output,
+          output_file,
+          indent=4,
+      )
+
+
+if __name__ == "__main__":
   gc.disable()
-  main(parsed_args)
+  main(parse_args())
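
Note: the refactor above moves all argument parsing out of the __main__ block into parse_args(), leaving main() to run the benchmark against an already-parsed namespace. A condensed sketch of the resulting structure (the real parse_args() defines many more flags; only --seed is shown here for illustration):

# Condensed sketch of the new entry-point layout, not the full script.
import argparse
import gc
import random

import numpy as np


def parse_args() -> argparse.Namespace:
  # All CLI definitions live in one place and return a parsed namespace.
  parser = argparse.ArgumentParser(
      description="Benchmark the online serving throughput."
  )
  parser.add_argument("--seed", type=int, default=0)
  return parser.parse_args()


def main(args: argparse.Namespace):
  # main() only consumes the namespace; it no longer builds the parser.
  print(args)
  random.seed(args.seed)
  np.random.seed(args.seed)
  # ... run warmup, the benchmark, and result/output saving here ...


if __name__ == "__main__":
  gc.disable()
  main(parse_args())

One practical benefit of this split is that main() can be imported and driven with a hand-built argparse.Namespace in tests, without touching sys.argv.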
