Skip to content

Commit 00dfb9c

Browse files
author
joshistoast
committed
fix(prompt): improve numeric weighting calculation
1 parent 46570e9 commit 00dfb9c

File tree

2 files changed

+169
-23
lines changed

2 files changed

+169
-23
lines changed

invokeai/frontend/web/src/common/util/promptAttention.test.ts

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,108 @@ describe('adjustPromptAttention', () => {
168168
expect(result.prompt.slice(result.selectionStart, result.selectionEnd)).toBe('a+');
169169
});
170170
});
171+
172+
describe('numeric attention weights', () => {
173+
it('should preserve parentheses for single word with numeric weight when incrementing elsewhere', () => {
174+
const prompt = '(masterpiece)1.3, best quality';
175+
const len = prompt.length;
176+
// Select "best quality" and increment
177+
const bestQualityStart = prompt.indexOf('best quality');
178+
const result = adjustPromptAttention(prompt, bestQualityStart, len, 'increment');
179+
180+
// masterpiece should keep its parens and exact weight
181+
expect(result.prompt).toContain('(masterpiece)1.3');
182+
expect(result.prompt).not.toContain('masterpiece1.3');
183+
});
184+
185+
it('should not produce long floating point numbers for numeric weights', () => {
186+
const prompt = '(high detail)1.2, oil painting';
187+
const len = prompt.length;
188+
// Select "oil painting" and increment
189+
const oilStart = prompt.indexOf('oil painting');
190+
const result = adjustPromptAttention(prompt, oilStart, len, 'increment');
191+
192+
// high detail should keep its exact weight, no floating point garbage
193+
expect(result.prompt).toContain('(high detail)1.2');
194+
expect(result.prompt).not.toMatch(/1\.19999/);
195+
expect(result.prompt).not.toMatch(/1\.20000/);
196+
});
197+
198+
it('should preserve numeric weight 1.15 without floating point corruption', () => {
199+
const prompt = '(sunny midday light)1.15, landscape';
200+
const len = prompt.length;
201+
const landscapeStart = prompt.indexOf('landscape');
202+
const result = adjustPromptAttention(prompt, landscapeStart, len, 'increment');
203+
204+
expect(result.prompt).toContain('(sunny midday light)1.15');
205+
expect(result.prompt).not.toMatch(/1\.15005/);
206+
});
207+
208+
it('should normalize numeric 1.1 weight to + syntax', () => {
209+
const prompt = '(lush rolling hills)1.1, landscape';
210+
const len = prompt.length;
211+
const landscapeStart = prompt.indexOf('landscape');
212+
const result = adjustPromptAttention(prompt, landscapeStart, len, 'increment');
213+
214+
// 1.1 is equivalent to +, normalization is acceptable
215+
expect(result.prompt).toMatch(/\(lush rolling hills\)(\+|1\.1)/);
216+
});
217+
218+
it('should increment numeric weight correctly for single word', () => {
219+
const prompt = '(masterpiece)1.3';
220+
const result = adjustPromptAttention(prompt, 0, prompt.length, 'increment');
221+
222+
// 1.3 + 0.1 = 1.4
223+
expect(result.prompt).toBe('(masterpiece)1.4');
224+
});
225+
226+
it('should increment numeric weight correctly for multi-word group', () => {
227+
const prompt = '(high detail)1.2';
228+
const result = adjustPromptAttention(prompt, 0, prompt.length, 'increment');
229+
230+
// 1.2 + 0.1 = 1.3
231+
expect(result.prompt).toBe('(high detail)1.3');
232+
});
233+
234+
it('should decrement numeric weight correctly', () => {
235+
const prompt = '(masterpiece)1.3';
236+
const result = adjustPromptAttention(prompt, 0, prompt.length, 'decrement');
237+
238+
// 1.3 - 0.1 = 1.2
239+
expect(result.prompt).toBe('(masterpiece)1.2');
240+
});
241+
242+
it('should increment numeric weight 1.15 with additive step', () => {
243+
const prompt = '(sunny midday light)1.15';
244+
const result = adjustPromptAttention(prompt, 0, prompt.length, 'increment');
245+
246+
// 1.15 + 0.1 = 1.25
247+
expect(result.prompt).toBe('(sunny midday light)1.25');
248+
});
249+
250+
it('should decrement numeric weight 1.15 with additive step', () => {
251+
const prompt = '(sunny midday light)1.15';
252+
const result = adjustPromptAttention(prompt, 0, prompt.length, 'decrement');
253+
254+
// 1.15 - 0.1 = 1.05
255+
expect(result.prompt).toBe('(sunny midday light)1.05');
256+
});
257+
258+
it('should handle the full bug report prompt without corrupting non-selected weights', () => {
259+
const prompt =
260+
'(masterpiece)1.3, best quality, (high detail)1.2, oil painting, (sunny midday light)1.15, an old stone castle standing on a hill, medieval architecture, weathered stone walls, (lush rolling hills)1.1, expansive landscape, clear blue sky';
261+
const selStart = prompt.indexOf('clear blue sky');
262+
const selEnd = selStart + 'clear blue sky'.length;
263+
const result = adjustPromptAttention(prompt, selStart, selEnd, 'increment');
264+
265+
// Non-selected numeric weights must be preserved exactly
266+
expect(result.prompt).toContain('(masterpiece)1.3');
267+
expect(result.prompt).toContain('(high detail)1.2');
268+
expect(result.prompt).toContain('(sunny midday light)1.15');
269+
// Selected text should be incremented
270+
expect(result.prompt).toContain('(clear blue sky)+');
271+
// No floating point garbage anywhere
272+
expect(result.prompt).not.toMatch(/\d\.\d{5,}/);
273+
});
274+
});
171275
});

invokeai/frontend/web/src/common/util/promptAttention.ts

Lines changed: 65 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,29 @@ type AttentionDirection = 'increment' | 'decrement';
99
type AdjustmentResult = { prompt: string; selectionStart: number; selectionEnd: number };
1010

1111
const ATTENTION_STEP = 1.1;
12+
const NUMERIC_ATTENTION_STEP = 0.1;
13+
14+
/**
15+
* Check if a weight is approximately ATTENTION_STEP^n for some integer n.
16+
* Returns n if so, or null if the weight is not a power of ATTENTION_STEP.
17+
*/
18+
function getAttentionStepCount(weight: number): number | null {
19+
if (weight <= 0) {
20+
return null;
21+
}
22+
if (Math.abs(weight - 1.0) < 0.001) {
23+
return 0;
24+
}
25+
const n = Math.round(Math.log(weight) / Math.log(ATTENTION_STEP));
26+
if (n === 0) {
27+
return null;
28+
}
29+
const expected = Math.pow(ATTENTION_STEP, n);
30+
if (Math.abs(expected - weight) < 0.005) {
31+
return n;
32+
}
33+
return null;
34+
}
1235

1336
/**
1437
* Adjusts the attention of the prompt at the current cursor/selection position.
@@ -68,10 +91,20 @@ export function adjustPromptAttention(
6891
}
6992

7093
for (const terminal of selectedTerminals) {
71-
if (direction === 'increment') {
72-
terminal.weight *= ATTENTION_STEP;
94+
if (terminal.hasNumericAttention) {
95+
// Additive step for explicit numeric weights (e.g. 1.15 → 1.25 / 1.05)
96+
if (direction === 'increment') {
97+
terminal.weight = Number((terminal.weight + NUMERIC_ATTENTION_STEP).toFixed(4));
98+
} else {
99+
terminal.weight = Number((terminal.weight - NUMERIC_ATTENTION_STEP).toFixed(4));
100+
}
73101
} else {
74-
terminal.weight /= ATTENTION_STEP;
102+
// Multiplicative step for +/- syntax weights
103+
if (direction === 'increment') {
104+
terminal.weight *= ATTENTION_STEP;
105+
} else {
106+
terminal.weight /= ATTENTION_STEP;
107+
}
75108
}
76109
}
77110

@@ -97,28 +130,37 @@ type Terminal = {
97130
weight: number;
98131
range: { start: number; end: number };
99132
hasExplicitAttention: boolean;
133+
hasNumericAttention: boolean;
100134
parentRange?: { start: number; end: number };
101135
isSelected: boolean;
102136
};
103137

104-
function flattenAST(ast: ASTNode[], currentWeight = 1.0, parentRange?: { start: number; end: number }): Terminal[] {
138+
function flattenAST(
139+
ast: ASTNode[],
140+
currentWeight = 1.0,
141+
parentRange?: { start: number; end: number },
142+
numericAttention = false
143+
): Terminal[] {
105144
let terminals: Terminal[] = [];
106145

107146
for (const node of ast) {
108147
let nodeWeight = currentWeight;
148+
let nodeNumericAttention = numericAttention;
109149
if ('attention' in node && node.attention) {
110150
nodeWeight *= parseAttention(node.attention);
151+
nodeNumericAttention = typeof node.attention === 'number';
111152
}
112153

113154
if (node.type === 'group') {
114-
terminals.push(...flattenAST(node.children, nodeWeight, node.range));
155+
terminals.push(...flattenAST(node.children, nodeWeight, node.range, nodeNumericAttention));
115156
} else {
116157
terminals.push({
117158
text: node.type === 'word' ? node.text : node.value,
118159
type: node.type,
119160
weight: nodeWeight,
120161
range: node.range,
121162
hasExplicitAttention: 'attention' in node && !!node.attention,
163+
hasNumericAttention: nodeNumericAttention,
122164
parentRange: parentRange,
123165
isSelected: false,
124166
});
@@ -234,9 +276,14 @@ function groupTerminals(terminals: Terminal[]): ASTNode[] {
234276
return j;
235277
};
236278

237-
// Check for + (>= 1.1)
238-
if (weight >= ATTENTION_STEP - 0.001) {
239-
const j = findRunEnd((w) => w >= ATTENTION_STEP - 0.001);
279+
const stepCount = getAttentionStepCount(weight);
280+
281+
// Check for + (positive power of ATTENTION_STEP)
282+
if (stepCount !== null && stepCount > 0) {
283+
const j = findRunEnd((w) => {
284+
const sc = getAttentionStepCount(w);
285+
return sc !== null && sc > 0;
286+
});
240287

241288
let runStart = i;
242289
let runEnd = j;
@@ -277,9 +324,12 @@ function groupTerminals(terminals: Terminal[]): ASTNode[] {
277324
continue;
278325
}
279326

280-
// Check for - (<= 0.909)
281-
if (weight <= 1 / ATTENTION_STEP + 0.001) {
282-
const j = findRunEnd((w) => w <= 1 / ATTENTION_STEP + 0.001);
327+
// Check for - (negative power of ATTENTION_STEP)
328+
if (stepCount !== null && stepCount < 0) {
329+
const j = findRunEnd((w) => {
330+
const sc = getAttentionStepCount(w);
331+
return sc !== null && sc < 0;
332+
});
283333

284334
let runStart = i;
285335
let runEnd = j;
@@ -336,16 +386,8 @@ function groupTerminals(terminals: Terminal[]): ASTNode[] {
336386

337387
const weightStr = Number(weight.toFixed(4));
338388

339-
if (children.length === 1) {
340-
const child = children[0]!;
341-
if (child.type === 'word' || child.type === 'group') {
342-
nodes.push({ ...child, attention: weightStr });
343-
} else {
344-
nodes.push({ type: 'group', children, attention: weightStr, range: { start: 0, end: 0 }, isSelection });
345-
}
346-
} else {
347-
nodes.push({ type: 'group', children, attention: weightStr, range: { start: 0, end: 0 }, isSelection });
348-
}
389+
// Always create a group for numeric weights to preserve parentheses in output
390+
nodes.push({ type: 'group', children, attention: weightStr, range: { start: 0, end: 0 }, isSelection });
349391
i = j;
350392
}
351393
}
@@ -377,10 +419,10 @@ function addAttention(current: Attention | undefined, added: string): Attention
377419
}
378420
if (typeof current === 'number') {
379421
if (added === '+') {
380-
return current * ATTENTION_STEP;
422+
return Number((current * ATTENTION_STEP).toFixed(4));
381423
}
382424
if (added === '-') {
383-
return current / ATTENTION_STEP;
425+
return Number((current / ATTENTION_STEP).toFixed(4));
384426
}
385427
return current;
386428
}

0 commit comments

Comments
 (0)