@@ -49,43 +49,37 @@ static const float NEAR_SCALE = 1.2f; // foreground blur emphasis
4949static const float FAR_SCALE = 1.0f ; // background blur scale
5050static const float BG_LEAK_PREVENTION = 0.5f ; // reduce background bleeding into foreground
5151
52+ // derived constants (compile-time)
53+ static const float COC_CLAMP_PIXELS = MAX_COC_RADIUS * COC_CLAMP_FACTOR;
54+ static const float INV_SAMPLE_COUNT = 1.0f / (float )SAMPLE_COUNT;
55+ static const float INV_SCATTER_NORM = 1.0f / (MAX_COC_RADIUS * 0.3f );
56+ static const float INV_OUTLIER_THRES = 1.0f / OUTLIER_THRESHOLD;
57+ static const float INV_FOCUS_REGION = 1.0f / FOCUS_REGION;
58+
5259/*------------------------------------------------------------------------------
53- circle of confusion calculation using thin lens model
54-
55- formula: coc = |A * f * (s - d) / (d * (s - f))|
56- where: A = aperture diameter, f = focal length
57- s = focus distance, d = pixel depth
60+ lens constants computed once per group then read by every thread
5861------------------------------------------------------------------------------*/
59- float compute_coc ( float pixel_depth, float focus_distance, float aperture_fstop, float2 resolution)
62+ struct lens_t
6063{
61- // convert focal length to meters
62- float f = FOCAL_LENGTH_MM * 0.001f ;
63-
64- // clamp distances to valid range
65- float s = max (focus_distance, f + 0.01f ); // focus distance
66- float d = max (pixel_depth, 0.01f ); // pixel depth
67-
68- // aperture diameter from f-stop: diameter = focal_length / f-stop
69- float aperture_diameter = f / max (aperture_fstop, 1.0f );
70-
71- // thin lens coc in meters
72- float coc_m = abs (aperture_diameter * (s - d) * f) / (d * abs (s - f) + FLT_MIN);
73-
74- // convert to pixels: pixels = coc_meters * (resolution / sensor_size)
75- float sensor_m = SENSOR_HEIGHT_MM * 0.001f ;
76- float coc_pixels = coc_m * (resolution.y / sensor_m);
77-
78- // sign indicates near (-) vs far (+) field
79- float sign = (d < s) ? -1.0f : 1.0f ;
80-
81- // scale differently for near/far
82- float scale = (sign < 0.0f ) ? NEAR_SCALE : FAR_SCALE;
83- coc_pixels *= scale;
84-
85- // clamp to reasonable range
86- coc_pixels = min (abs (coc_pixels), MAX_COC_RADIUS * COC_CLAMP_FACTOR);
87-
88- return sign * coc_pixels;
64+ float focus_distance; // s, focus distance in meters clamped above focal length
65+ float coc_factor; // aperture_diameter * f * resolution.y / sensor_m / abs(s - f)
66+ };
67+
68+ groupshared lens_t gs_lens;
69+
70+ /*------------------------------------------------------------------------------
71+ fast signed coc using precomputed lens factor
72+ sign is negative for foreground (d < s) and positive for background (d > s)
73+ ------------------------------------------------------------------------------*/
74+ float compute_coc_signed (float depth, lens_t lens)
75+ {
76+ float d = max (depth, 0.01f );
77+ float s_minus_d = lens.focus_distance - d;
78+ float coc_pix = abs (s_minus_d) * lens.coc_factor / d;
79+ bool is_near = s_minus_d > 0.0f ;
80+ float scale = is_near ? NEAR_SCALE : FAR_SCALE;
81+ coc_pix = min (coc_pix * scale, COC_CLAMP_PIXELS);
82+ return is_near ? -coc_pix : coc_pix;
8983}
9084
9185/*------------------------------------------------------------------------------
@@ -98,139 +92,101 @@ float compute_focus_distance(float2 resolution)
9892{
9993 float2 center = float2 (0.5f , 0.5f );
10094
101- // collect depth samples
10295 float depths[FOCUS_SAMPLES];
10396 float weights[FOCUS_SAMPLES];
10497 float weight_sum = 0.0f ;
10598
106- // spiral sampling with golden angle
99+ [unroll]
107100 for (int i = 0 ; i < FOCUS_SAMPLES; i++)
108101 {
109- float t = (float )i / (float )(FOCUS_SAMPLES - 1 );
110- float angle = i * GOLDEN_ANGLE;
102+ float t = (float )i / (float )(FOCUS_SAMPLES - 1 );
103+ float angle = i * GOLDEN_ANGLE;
111104 float radius = sqrt (t) * FOCUS_REGION;
112105
113- float2 offset = float2 (cos (angle), sin (angle)) * radius;
114- float2 uv = center + offset;
106+ float sin_a, cos_a;
107+ sincos (angle, sin_a, cos_a);
108+ float2 offset = float2 (cos_a, sin_a) * radius;
109+ float2 uv = center + offset;
115110
116- depths[i] = get_linear_depth (uv * buffer_frame.resolution_scale);
117-
118- // weight by proximity to center (gaussian-ish falloff)
119- float dist = length (offset) / FOCUS_REGION;
120- weights[i] = exp (-dist * dist * CENTER_WEIGHT_BIAS);
121- weight_sum += weights[i];
111+ depths[i] = get_linear_depth (uv * buffer_frame.resolution_scale);
112+ float dist_n = length (offset) * INV_FOCUS_REGION;
113+ weights[i] = exp (-dist_n * dist_n * CENTER_WEIGHT_BIAS);
114+ weight_sum += weights[i];
122115 }
123116
124- // first pass: weighted average for approximate median
125117 float weighted_avg = 0.0f ;
126- for (int i = 0 ; i < FOCUS_SAMPLES; i++)
118+ [unroll]
119+ for (int j = 0 ; j < FOCUS_SAMPLES; j++)
127120 {
128- weighted_avg += depths[i ] * weights[i ];
121+ weighted_avg += depths[j ] * weights[j ];
129122 }
130123 weighted_avg /= max (weight_sum, FLT_MIN);
131124
132- // second pass: reject outliers and refine
133- float refined_sum = 0.0f ;
125+ float refined_sum = 0.0f ;
134126 float refined_weight = 0.0f ;
127+ float inv_avg = 1.0f / max (weighted_avg, 0.1f );
135128
136- for (int i = 0 ; i < FOCUS_SAMPLES; i++)
129+ [unroll]
130+ for (int k = 0 ; k < FOCUS_SAMPLES; k++)
137131 {
138- float deviation = abs (depths[i] - weighted_avg) / max (weighted_avg, 0.1f );
139-
132+ float deviation = abs (depths[k] - weighted_avg) * inv_avg;
140133 if (deviation < OUTLIER_THRESHOLD)
141134 {
142- // closer to average = higher contribution
143- float confidence = 1.0f - (deviation / OUTLIER_THRESHOLD);
144- confidence *= confidence; // square for sharper falloff
145-
146- float w = weights[i] * confidence;
147- refined_sum += depths[i] * w;
148- refined_weight += w;
135+ float confidence = 1.0f - deviation * INV_OUTLIER_THRES;
136+ confidence *= confidence;
137+ float w = weights[k] * confidence;
138+ refined_sum += depths[k] * w;
139+ refined_weight += w;
149140 }
150141 }
151142
152- // fallback if too many outliers
153143 return (refined_weight > FLT_MIN) ? (refined_sum / refined_weight) : weighted_avg;
154144}
155145
156- /*------------------------------------------------------------------------------
157- depth-aware sample weighting for bokeh blur
158-
159- handles the scatter/gather mismatch by simulating how out-of-focus
160- samples would "scatter" light to cover nearby pixels
161- ------------------------------------------------------------------------------*/
162- float sample_weight (float sample_coc, float center_coc, float sample_depth, float center_depth, float sample_distance)
163- {
164- // how much of the blur disk covers this sample position
165- float effective_coc = max (abs (sample_coc), abs (center_coc));
166- float coverage = saturate (1.0f - sample_distance / max (effective_coc, FLT_MIN));
167-
168- // depth-based occlusion: prevent background from bleeding into foreground
169- float depth_weight = 1.0f ;
170- bool sample_behind = sample_depth > center_depth;
171- bool center_is_fg = center_coc < 0.0f ; // negative coc = foreground
172-
173- if (sample_behind && center_is_fg)
174- {
175- // background sample trying to contribute to foreground pixel
176- depth_weight = BG_LEAK_PREVENTION;
177- }
178-
179- // larger coc samples contribute more (they scatter over larger area)
180- float scatter_factor = saturate (abs (sample_coc) / (MAX_COC_RADIUS * 0.3f ));
181- scatter_factor = lerp (0.3f , 1.0f , scatter_factor);
182-
183- return coverage * depth_weight * scatter_factor;
184- }
185-
186146/*------------------------------------------------------------------------------
187147 main bokeh blur with gather sampling
188148------------------------------------------------------------------------------*/
189- float3 bokeh_gather (float2 uv, float center_coc, float center_depth, float focus_dist, float aperture , float2 texel_size, float2 resolution)
149+ float3 bokeh_gather (float2 uv, float center_coc, float center_depth, lens_t lens , float2 texel_size, float2 resolution)
190150{
191151 float blur_radius = abs (center_coc);
152+
153+ float3 color_sum = tex.SampleLevel (samplers[sampler_bilinear_clamp], uv, 0 ).rgb;
154+ float weight_sum = 1.0f ;
192155
193- // early out for in-focus pixels
194- if (blur_radius < IN_FOCUS_THRESHOLD)
195- {
196- return tex.SampleLevel (samplers[sampler_bilinear_clamp], uv, 0 ).rgb;
197- }
198-
199- // accumulate samples
200- float3 color_sum = float3 (0.0f , 0.0f , 0.0f );
201- float weight_sum = 0.0f ;
202-
203- // center sample always included
204- float3 center_color = tex.SampleLevel (samplers[sampler_bilinear_clamp], uv, 0 ).rgb;
205- color_sum += center_color;
206- weight_sum += 1.0f ;
207-
156+ bool center_is_fg = center_coc < 0.0f ;
157+
208158 // randomize starting angle per pixel for temporal stability with taa
209- float start_angle = noise_interleaved_gradient (uv * resolution, true ) * PI2;
210- float angle = start_angle;
159+ float angle = noise_interleaved_gradient (uv * resolution, true ) * PI2;
211160
212- // golden angle spiral sampling
161+ [loop]
213162 for (int i = 0 ; i < SAMPLE_COUNT; i++)
214163 {
215164 angle += GOLDEN_ANGLE;
216165
217- // sqrt distribution: more samples near center
218- float t = (float )(i + 1 ) / (float )SAMPLE_COUNT;
166+ float t = (float )(i + 1 ) * INV_SAMPLE_COUNT;
219167 float r = sqrt (t) * blur_radius;
220168
221- float2 offset = float2 (cos (angle), sin (angle)) * r * texel_size;
169+ float sin_a, cos_a;
170+ sincos (angle, sin_a, cos_a);
171+ float2 offset = float2 (cos_a, sin_a) * r * texel_size;
222172 float2 sample_uv = uv + offset;
223173
224174 if (!is_valid_uv (sample_uv))
225175 continue ;
226176
227177 float3 sample_color = tex.SampleLevel (samplers[sampler_bilinear_clamp], sample_uv, 0 ).rgb;
228- float sample_depth = get_linear_depth (sample_uv * buffer_frame.resolution_scale);
229- float sample_coc = compute_coc (sample_depth, focus_dist, aperture, resolution);
230-
231- float w = sample_weight (sample_coc, center_coc, sample_depth, center_depth, r);
178+ float sample_depth = get_linear_depth (sample_uv * buffer_frame.resolution_scale);
179+ float sample_coc = compute_coc_signed (sample_depth, lens);
180+ float abs_sample_coc = abs (sample_coc);
181+
182+ // inlined sample_weight
183+ float effective_coc = max (abs_sample_coc, blur_radius);
184+ float coverage = saturate (1.0f - r / max (effective_coc, FLT_MIN));
185+ float depth_weight = (sample_depth > center_depth && center_is_fg) ? BG_LEAK_PREVENTION : 1.0f ;
186+ float scatter = lerp (0.3f , 1.0f , saturate (abs_sample_coc * INV_SCATTER_NORM));
187+ float w = coverage * depth_weight * scatter;
232188
233- color_sum += sample_color * w;
189+ color_sum += sample_color * w;
234190 weight_sum += w;
235191 }
236192
@@ -241,35 +197,49 @@ float3 bokeh_gather(float2 uv, float center_coc, float center_depth, float focus
241197 compute shader entry point
242198------------------------------------------------------------------------------*/
243199[numthreads (THREAD_GROUP_COUNT_X, THREAD_GROUP_COUNT_Y, 1 )]
244- void main_cs (uint3 thread_id : SV_DispatchThreadID )
200+ void main_cs (uint3 thread_id : SV_DispatchThreadID , uint group_index : SV_GroupIndex )
245201{
246202 float2 resolution;
247203 tex_uav.GetDimensions (resolution.x, resolution.y);
248-
204+
205+ // one thread per group computes the lens constants and shares them with the rest
206+ if (group_index == 0 )
207+ {
208+ float aperture_fstop = pass_get_f3_value ().x;
209+ float f = FOCAL_LENGTH_MM * 0.001f ;
210+ float aperture_diameter = f / max (aperture_fstop, 1.0f );
211+ float sensor_m = SENSOR_HEIGHT_MM * 0.001f ;
212+ float pixels_per_meter = resolution.y / sensor_m;
213+ float focus_distance = compute_focus_distance (resolution);
214+ float s = max (focus_distance, f + 0.01f );
215+
216+ gs_lens.focus_distance = s;
217+ gs_lens.coc_factor = (aperture_diameter * f * pixels_per_meter) / (abs (s - f) + FLT_MIN);
218+ }
219+ GroupMemoryBarrierWithGroupSync ();
220+
249221 if (any (thread_id.xy >= uint2 (resolution)))
250222 return ;
251-
252- float2 uv = (thread_id.xy + 0.5f ) / resolution;
223+
224+ lens_t lens = gs_lens;
225+ float2 uv = (thread_id.xy + 0.5f ) / resolution;
226+ float depth = get_linear_depth (uv * buffer_frame.resolution_scale);
227+ float coc = compute_coc_signed (depth, lens);
228+ float blur_radius = abs (coc);
229+
230+ // in-focus passthrough, smoothstep blend below 0.5 px is sub-perceptual so we skip the gather entirely
231+ if (blur_radius < IN_FOCUS_THRESHOLD)
232+ {
233+ tex_uav[thread_id.xy] = tex[thread_id.xy];
234+ return ;
235+ }
236+
253237 float2 texel_size = 1.0f / resolution;
254-
255- // get aperture from pass constants
256- float aperture = pass_get_f3_value ().x;
257-
258- // compute auto-focus distance
259- float focus_distance = compute_focus_distance (resolution);
260-
261- // compute coc for this pixel
262- float depth = get_linear_depth (uv * buffer_frame.resolution_scale);
263- float coc = compute_coc (depth, focus_distance, aperture, resolution);
264-
265- // perform depth-aware bokeh blur
266- float3 blurred = bokeh_gather (uv, coc, depth, focus_distance, aperture, texel_size, resolution);
267-
268- // blend based on blur amount
238+ float3 blurred = bokeh_gather (uv, coc, depth, lens, texel_size, resolution);
239+
269240 float4 original = tex[thread_id.xy];
270- float blend = smoothstep (0.0f , 1.0f , abs (coc) / MAX_COC_RADIUS);
271-
272- float3 result = lerp (original.rgb, blurred, blend);
273-
241+ float blend = smoothstep (0.0f , 1.0f , blur_radius / MAX_COC_RADIUS);
242+ float3 result = lerp (original.rgb, blurred, blend);
243+
274244 tex_uav[thread_id.xy] = float4 (result, original.a);
275245}
0 commit comments