Skip to content

Commit 33cc2a2

Browse files
committed
fix RNG and GEMM issues
1 parent 3f16652 commit 33cc2a2

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

applications/llama_3.2_1b/inference.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,16 @@ def inference(
138138
)
139139
logging.info("Model and tokenizer loaded.")
140140

141+
# Important: Set the seed again after initialization of the model. Each
142+
# call that initializes an nn.Linear layer updates the RNG state, because
143+
# weights are initialized with random values. For different JSON
144+
# configurations, we initialize a different number of linear layers,
145+
# so different configurations result in a different RNG state here. Since
146+
# we use random numbers to sample from the token distribution during
147+
# inference, it is important to have the same RNG state between runs so we
148+
# can have reproducible results across configurations.
149+
torch.manual_seed(1608560892)
150+
141151
hook_handles = []
142152
if save_outputs:
143153
if os.path.exists(output_data_path):

applications/llama_3.2_1b/src/operator/aie_gemm.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,19 @@ def __init__(
8282

8383
def set_up(self):
8484
# Describe required artifacts (xclbin, insts.bin)
85-
file_name_tile_base = f"{self.tile_m}x{self.tile_k}x{self.tile_n}"
85+
# file_name_tile_base = f"{self.tile_m}x{self.tile_k}x{self.tile_n}"
8686
file_name_total_base = (
8787
f"{self.M}x{self.K}x{self.N}_{self.tile_m}x{self.tile_k}x{self.tile_n}"
8888
)
89+
# FIXME: We should be able to reuse the same xclbin for same tile
90+
# sizes, only swapping out the instruction sequence for different
91+
# problem sizes. However, there seem to be cases where this does
92+
# not work and the GEMM appears to be misconfigured for the wrong
93+
# size (resulting in a timeout when trying to run it). Perhaps
94+
# XRT is caching something, or something is wrong with the run-
95+
# time parameter (synchronization)? For now, create separate
96+
# xclbins for each problem size.
97+
file_name_tile_base = file_name_total_base
8998
xclbin_kernel_name = f"gemm_{file_name_tile_base}"
9099
kernel_flags = [
91100
f"-DDIM_M={self.tile_m}",

0 commit comments

Comments (0)