from datetime import datetime
import gc
import json
+import os
import random
import time
from typing import Any, AsyncGenerator, Optional
-import os
-

+from benchmarks.eval_accuracy import eval_accuracy
+from benchmarks.metrics import CounterMetric, EventMetric
import grpc
-from benchmarks.metrics import EventMetric, CounterMetric
from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc
from jetstream.engine.token_utils import load_vocab
from jetstream.external_tokenizers.llama3 import llama3_tokenizer
import numpy as np
-from tqdm.asyncio import tqdm  # pytype: disable=pyi-error
import pandas
-
-from eval_accuracy import eval_accuracy
+from tqdm.asyncio import tqdm  # pytype: disable=pyi-error
from transformers import AutoTokenizer


@@ -706,136 +704,7 @@ def sample_warmup_requests(requests):
        break


-def main(args: argparse.Namespace):
-  print(args)
-  random.seed(args.seed)
-  np.random.seed(args.seed)
-
-  model_id = args.model
-  tokenizer_id = args.tokenizer
-  use_hf_tokenizer = args.use_hf_tokenizer
-
-  prefill_quota = AsyncCounter(init_value=3)
-  active_req_quota = AsyncCounter(init_value=450)
-
-  api_url = f"{args.server}:{args.port}"
-
-  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
-  if tokenizer == "test" or args.dataset == "test":
-    input_requests = mock_requests(
-        args.total_mock_requests
-    )  # e.g. [("AB", 2, "AB", 3)]
-  else:
-    dataset = []
-    if args.dataset == "openorca":
-      dataset = load_openorca_dataset_pkl(args.dataset_path)
-    elif args.dataset == "sharegpt":
-      dataset = load_sharegpt_dataset(
-          args.dataset_path,
-          args.conversation_starter,
-      )
-
-    # A given args.max_output_length value is the max generation step;
-    # when args.max_output_length defaults to None, the sample's golden
-    # output length will be used to decide the generation step.
-    input_requests = sample_requests(
-        dataset=dataset,
-        tokenizer=tokenizer,
-        num_requests=args.num_prompts,
-        max_output_length=args.max_output_length,
-    )
-
-  warmup_requests = None
-  if args.warmup_mode == "full":
-    warmup_requests = input_requests
-  elif args.warmup_mode == "sampled":
-    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
-
-  if warmup_requests:
-    print(f"Warmup (mode: {args.warmup_mode}) is starting.")
-    _, _ = asyncio.run(
-        benchmark(
-            api_url=api_url,
-            tokenizer=tokenizer,
-            input_requests=warmup_requests,
-            request_rate=args.request_rate,
-            disable_tqdm=args.disable_tqdm,
-            prefill_quota=prefill_quota,
-            active_req_quota=active_req_quota,
-            is_warmup=True,
-        )
-    )
-    print(f"Warmup (mode: {args.warmup_mode}) has completed.")
-
-    # TODO: Replace this with warmup complete signal once supported.
-    # Wait for the server to completely warm up before running the benchmark.
-    time.sleep(5)
-
-  benchmark_result, request_outputs = asyncio.run(
-      benchmark(
-          api_url=api_url,
-          tokenizer=tokenizer,
-          input_requests=input_requests,
-          request_rate=args.request_rate,
-          disable_tqdm=args.disable_tqdm,
-          prefill_quota=prefill_quota,
-          active_req_quota=active_req_quota,
-      )
-  )
-
-  # Process output
-  output = [output.to_dict() for output in request_outputs]
-  if args.run_eval:
-    eval_json = eval_accuracy(output)
-
-  # Save config and results to json
-  if args.save_result:
-    # dimensions values are strings
-    dimensions_json = {}
-    # metrics values are numerical
-    metrics_json = {}
-
-    # Setup
-    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
-    dimensions_json["date"] = current_dt
-    dimensions_json["model_id"] = model_id
-    dimensions_json["tokenizer_id"] = tokenizer_id
-    if args.additional_metadata_metrics_to_save is not None:
-      dimensions_json = {
-          **dimensions_json,
-          **json.loads(args.additional_metadata_metrics_to_save),
-      }
-    metrics_json["num_prompts"] = args.num_prompts
-
-    # Traffic
-    metrics_json["request_rate"] = args.request_rate
-    metrics_json = {**metrics_json, **benchmark_result}
-    if args.run_eval:
-      metrics_json = {**metrics_json, **eval_json}
-
-    final_json = {}
-    final_json["metrics"] = metrics_json
-    final_json["dimensions"] = dimensions_json
-
-    # Save to file
-    base_model_id = model_id.split("/")[-1]
-    file_name = (
-        f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
-    )
-    with open(file_name, "w", encoding="utf-8") as outfile:
-      json.dump(final_json, outfile)
-
-  if args.save_request_outputs:
-    file_path = args.request_outputs_file_path
-    with open(file_path, "w", encoding="utf-8") as output_file:
-      json.dump(
-          output,
-          output_file,
-          indent=4,
-      )
-
-
-if __name__ == "__main__":
+def parse_args() -> argparse.Namespace:
  parser = argparse.ArgumentParser(
      description="Benchmark the online serving throughput."
  )
@@ -909,7 +778,6 @@ def main(args: argparse.Namespace):
      default=150,
      help="The maximum number of mock requests to send for benchmark testing.",
  )
-
  parser.add_argument(
      "--max-output-length",
      type=int,
@@ -926,7 +794,6 @@ def main(args: argparse.Namespace):
          "the output length of the golden dataset would be passed."
      ),
  )
-
  parser.add_argument("--seed", type=int, default=0)
  parser.add_argument(
      "--disable-tqdm",
@@ -977,7 +844,138 @@ def main(args: argparse.Namespace):
      choices=["human", "gpt", "both"],
      help="What entity should be the one starting the conversations.",
  )
+  return parser.parse_args()
+
+
+def main(args: argparse.Namespace):
+  print(args)
+  random.seed(args.seed)
+  np.random.seed(args.seed)
+
+  model_id = args.model
+  tokenizer_id = args.tokenizer
+  use_hf_tokenizer = args.use_hf_tokenizer
+
+  prefill_quota = AsyncCounter(init_value=3)
+  active_req_quota = AsyncCounter(init_value=450)
+
+  api_url = f"{args.server}:{args.port}"
+
+  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
+  if tokenizer == "test" or args.dataset == "test":
+    input_requests = mock_requests(
+        args.total_mock_requests
+    )  # e.g. [("AB", 2, "AB", 3)]
+  else:
+    dataset = []
+    if args.dataset == "openorca":
+      dataset = load_openorca_dataset_pkl(args.dataset_path)
+    elif args.dataset == "sharegpt":
+      dataset = load_sharegpt_dataset(
+          args.dataset_path,
+          args.conversation_starter,
+      )
+
+    # A given args.max_output_length value is the max generation step;
+    # when args.max_output_length defaults to None, the sample's golden
+    # output length will be used to decide the generation step.
+    input_requests = sample_requests(
+        dataset=dataset,
+        tokenizer=tokenizer,
+        num_requests=args.num_prompts,
+        max_output_length=args.max_output_length,
+    )
+
+  warmup_requests = None
+  if args.warmup_mode == "full":
+    warmup_requests = input_requests
+  elif args.warmup_mode == "sampled":
+    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
+
+  if warmup_requests:
+    print(f"Warmup (mode: {args.warmup_mode}) is starting.")
+    _, _ = asyncio.run(
+        benchmark(
+            api_url=api_url,
+            tokenizer=tokenizer,
+            input_requests=warmup_requests,
+            request_rate=args.request_rate,
+            disable_tqdm=args.disable_tqdm,
+            prefill_quota=prefill_quota,
+            active_req_quota=active_req_quota,
+            is_warmup=True,
+        )
+    )
+    print(f"Warmup (mode: {args.warmup_mode}) has completed.")
+
+    # TODO: Replace this with warmup complete signal once supported.
+    # Wait for the server to completely warm up before running the benchmark.
+    time.sleep(5)
+
+  benchmark_result, request_outputs = asyncio.run(
+      benchmark(
+          api_url=api_url,
+          tokenizer=tokenizer,
+          input_requests=input_requests,
+          request_rate=args.request_rate,
+          disable_tqdm=args.disable_tqdm,
+          prefill_quota=prefill_quota,
+          active_req_quota=active_req_quota,
+      )
+  )
+
+  # Process output
+  output = [output.to_dict() for output in request_outputs]
+  if args.run_eval:
+    eval_json = eval_accuracy(output)
+
+  # Save config and results to json
+  if args.save_result:
+    # dimensions values are strings
+    dimensions_json = {}
+    # metrics values are numerical
+    metrics_json = {}

-  parsed_args = parser.parse_args()
+    # Setup
+    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
+    dimensions_json["date"] = current_dt
+    dimensions_json["model_id"] = model_id
+    dimensions_json["tokenizer_id"] = tokenizer_id
+    if args.additional_metadata_metrics_to_save is not None:
+      dimensions_json = {
+          **dimensions_json,
+          **json.loads(args.additional_metadata_metrics_to_save),
+      }
+    metrics_json["num_prompts"] = args.num_prompts
+
+    # Traffic
+    metrics_json["request_rate"] = args.request_rate
+    metrics_json = {**metrics_json, **benchmark_result}
+    if args.run_eval:
+      metrics_json = {**metrics_json, **eval_json}
+
+    final_json = {}
+    final_json["metrics"] = metrics_json
+    final_json["dimensions"] = dimensions_json
+
+    # Save to file
+    base_model_id = model_id.split("/")[-1]
+    file_name = (
+        f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
+    )
+    with open(file_name, "w", encoding="utf-8") as outfile:
+      json.dump(final_json, outfile)
+
+  if args.save_request_outputs:
+    file_path = args.request_outputs_file_path
+    with open(file_path, "w", encoding="utf-8") as output_file:
+      json.dump(
+          output,
+          output_file,
+          indent=4,
+      )
+
+
+if __name__ == "__main__":
  gc.disable()
-  main(parsed_args)
+  main(parse_args())