44from statistics import mean
55
66precision_sample_size = 4
7- precisions = [' float16' , ' float32' ]
7+ precisions = [" float16" , " float32" ]
88
99folder = os .path .dirname (os .path .dirname (__file__ ))
10- script = os .path .join (folder , ' mnist' , ' batch_eth_mnist.py' )
10+ script = os .path .join (folder , " mnist" , " batch_eth_mnist.py" )
1111data = {}
1212for precision in precisions :
1313 for _ in range (precision_sample_size ):
1414 result = subprocess .run (
1515 f"python { script } --n_train 100 --batch_size 50 --n_test 10 --n_updates 1 --w_dtype { precision } " ,
16- shell = True , capture_output = True , text = True
16+ shell = True ,
17+ capture_output = True ,
18+ text = True ,
1719 )
1820 output = result .stdout
19- time_match = re .search (r'Progress: 1 / 1 \((\d+\.\d+) seconds\)' , output )
20- memory_match = re .search (r'Memory consumption: (\d+)mb' , output )
21- data .setdefault (precision , []).append ([
22- time_match .groups ()[0 ],
23- memory_match .groups ()[0 ]
24- ])
21+ time_match = re .search (r"Progress: 1 / 1 \((\d+\.\d+) seconds\)" , output )
22+ memory_match = re .search (r"Memory consumption: (\d+)mb" , output )
23+ data .setdefault (precision , []).append (
24+ [time_match .groups ()[0 ], memory_match .groups ()[0 ]]
25+ )
2526 print ("+" )
2627
2728
2829def print_table (data ):
2930 column_widths = [max (len (str (item )) for item in col ) for col in zip (* data )]
3031 for row in data :
31- formatted_row = " | " .join (f"{ str (item ):<{column_widths [i ]}} " for i , item in enumerate (row ))
32+ formatted_row = " | " .join (
33+ f"{ str (item ):<{column_widths [i ]}} " for i , item in enumerate (row )
34+ )
3235 print (formatted_row )
3336
3437
3538average_time = {}
3639average_memory = {}
3740for precision , rows in data .items ():
3841 print (f"precision: { precision } " )
39- table = [
40- ['Time (sec)' , 'GPU memory (Mb)' ]
41- ] + rows
42+ table = [["Time (sec)" , "GPU memory (Mb)" ]] + rows
4243 avg_time = mean (map (lambda i : float (i [0 ]), rows ))
4344 avg_memory = mean (map (lambda i : float (i [1 ]), rows ))
4445 print_table (table )
4546 print (f"Average time: { avg_time } " )
4647 print (f"Average memory: { avg_memory } " )
4748 average_memory [precision ] = avg_memory
4849 average_time [precision ] = avg_time
49- print ('' )
50+ print ("" )
0 commit comments