Skip to content

Commit b5c608b

Browse files
committed
Refactor string formatting for improved readability in lowering_precision and batch_eth_mnist scripts
1 parent 2adfcba commit b5c608b

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

examples/benchmark/lowering_precision.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,47 @@
44
from statistics import mean
55

66
precision_sample_size = 4
7-
precisions = ['float16', 'float32']
7+
precisions = ["float16", "float32"]
88

99
folder = os.path.dirname(os.path.dirname(__file__))
10-
script = os.path.join(folder, 'mnist', 'batch_eth_mnist.py')
10+
script = os.path.join(folder, "mnist", "batch_eth_mnist.py")
1111
data = {}
1212
for precision in precisions:
1313
for _ in range(precision_sample_size):
1414
result = subprocess.run(
1515
f"python {script} --n_train 100 --batch_size 50 --n_test 10 --n_updates 1 --w_dtype {precision}",
16-
shell=True, capture_output=True, text=True
16+
shell=True,
17+
capture_output=True,
18+
text=True,
1719
)
1820
output = result.stdout
19-
time_match = re.search(r'Progress: 1 / 1 \((\d+\.\d+) seconds\)', output)
20-
memory_match = re.search(r'Memory consumption: (\d+)mb', output)
21-
data.setdefault(precision, []).append([
22-
time_match.groups()[0],
23-
memory_match.groups()[0]
24-
])
21+
time_match = re.search(r"Progress: 1 / 1 \((\d+\.\d+) seconds\)", output)
22+
memory_match = re.search(r"Memory consumption: (\d+)mb", output)
23+
data.setdefault(precision, []).append(
24+
[time_match.groups()[0], memory_match.groups()[0]]
25+
)
2526
print("+")
2627

2728

2829
def print_table(data):
2930
column_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
3031
for row in data:
31-
formatted_row = " | ".join(f"{str(item):<{column_widths[i]}}" for i, item in enumerate(row))
32+
formatted_row = " | ".join(
33+
f"{str(item):<{column_widths[i]}}" for i, item in enumerate(row)
34+
)
3235
print(formatted_row)
3336

3437

3538
average_time = {}
3639
average_memory = {}
3740
for precision, rows in data.items():
3841
print(f"precision: {precision}")
39-
table = [
40-
['Time (sec)', 'GPU memory (Mb)']
41-
] + rows
42+
table = [["Time (sec)", "GPU memory (Mb)"]] + rows
4243
avg_time = mean(map(lambda i: float(i[0]), rows))
4344
avg_memory = mean(map(lambda i: float(i[1]), rows))
4445
print_table(table)
4546
print(f"Average time: {avg_time}")
4647
print(f"Average memory: {avg_memory}")
4748
average_memory[precision] = avg_memory
4849
average_time[precision] = avg_time
49-
print('')
50+
print("")

examples/mnist/batch_eth_mnist.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,9 @@
386386

387387
print("\nAll activity accuracy: %.2f" % (accuracy["all"] / n_test))
388388
print("Proportion weighting accuracy: %.2f \n" % (accuracy["proportion"] / n_test))
389-
print(f"Memory consumption: {round(torch.cuda.max_memory_allocated(device=None) / 1024 ** 2)}mb")
389+
print(
390+
f"Memory consumption: {round(torch.cuda.max_memory_allocated(device=None) / 1024 ** 2)}mb"
391+
)
390392

391393
print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
392394
print("\nTesting complete.\n")

0 commit comments

Comments
 (0)