Skip to content

Commit 85fa6e7

Browse files
committed
hotfix: remove duplicate token metadata
1 parent 99ba84b commit 85fa6e7

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

lora_diffusion/lora_manager.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
def lora_join(lora_safetenors: list):
1414
metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
15+
_total_metadata = {}
1516
total_metadata = {}
1617
total_tensor = {}
1718
total_rank = 0
@@ -24,9 +25,14 @@ def lora_join(lora_safetenors: list):
2425

2526
assert len(set(rankset)) == 1, "Rank should be the same per model"
2627
total_rank += rankset[0]
27-
total_metadata.update(_metadata)
28+
_total_metadata.update(_metadata)
2829
ranklist.append(rankset[0])
2930

31+
# remove metadata about tokens
32+
for k, v in _total_metadata.items():
33+
if v != "<embed>":
34+
total_metadata[k] = v
35+
3036
tensorkeys = set()
3137
for safelora in lora_safetenors:
3238
tensorkeys.update(safelora.keys())
@@ -57,9 +63,6 @@ def lora_join(lora_safetenors: list):
5763

5864
print(f"Embedding {token} replaced to <s{idx}-{jdx}>")
5965

60-
if total_metadata.get(token, None) is not None:
61-
del total_metadata[token]
62-
6366
token_size_list.append(len(tokens))
6467

6568
return total_tensor, total_metadata, ranklist, token_size_list

0 commit comments

Comments (0)