Skip to content

Commit ebdd5b5

Browse files
Merge pull request #72 from saksham-jain177/enhancement/estimate-cost
Added offline token cost estimation with hard-isolated execution path
2 parents 112c890 + 67e4737 commit ebdd5b5

File tree

7 files changed

+303
-46
lines changed

7 files changed

+303
-46
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ dependencies = [
3434
"isort==5.13.2",
3535
"tomli==2.2.1",
3636
"claude-agent-sdk>=0.1.0",
37+
"tiktoken==0.12.0",
38+
"genai-prices==0.0.51",
3739
]
3840
classifiers = [
3941
"Development Status :: 5 - Production/Stable",

python_gpt_po/main.py

Lines changed: 86 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@
1111
import traceback
1212
from argparse import Namespace
1313
from dataclasses import dataclass
14-
from typing import Dict, List, Optional
14+
from typing import Any, Dict, List, Optional, Tuple
1515

1616
from .models.config import TranslationConfig, TranslationFlags
17-
from .models.enums import ModelProvider
1817
from .models.provider_clients import ProviderClients
1918
from .services.language_detector import LanguageDetector
2019
from .services.model_manager import ModelManager
2120
from .services.translation_service import TranslationService
2221
from .utils.cli import (auto_select_provider, create_language_mapping, get_provider_from_args, parse_args,
2322
show_help_and_exit, validate_provider_key)
2423
from .utils.config_loader import ConfigLoader
24+
from .utils.cost_estimator import CostEstimator
2525

2626

2727
def setup_logging(verbose: int = 0, quiet: bool = False):
@@ -53,20 +53,11 @@ def setup_logging(verbose: int = 0, quiet: bool = False):
5353
logging.getLogger().setLevel(level)
5454

5555

56-
def initialize_provider(args: Namespace) -> tuple[ProviderClients, ModelProvider, str]:
56+
def get_offline_provider_info(args: Namespace) -> Tuple[Any, Any, str]:
5757
"""
58-
Initialize the provider client and determine the appropriate model.
59-
60-
Args:
61-
args: Command line arguments from argparse
62-
63-
Returns:
64-
tuple: (provider_clients, provider, model)
65-
66-
Raises:
67-
SystemExit: If no valid provider can be found or initialized
58+
Get provider and model information without making network calls.
6859
"""
69-
# Initialize provider clients
60+
# Initialize provider clients (reads environment variables and args)
7061
provider_clients = ProviderClients()
7162
api_keys = provider_clients.initialize_clients(args)
7263

@@ -82,40 +73,43 @@ def initialize_provider(args: Namespace) -> tuple[ProviderClients, ModelProvider
8273
if not validate_provider_key(provider, api_keys):
8374
sys.exit(1)
8475

76+
# Determine model - use CLI arg or default
77+
model = args.model
78+
if not model:
79+
model = ModelManager.get_default_model(provider)
80+
81+
return provider_clients, provider, model
82+
83+
84+
def initialize_provider(args: Namespace, provider_clients: Any, provider: Any, model: str) -> Tuple[Any, Any, str]:
85+
"""
86+
Finalize provider initialization with network validation if needed.
87+
"""
8588
# Create model manager for model operations
8689
model_manager = ModelManager()
8790

88-
# List models if requested and exit
91+
# List models if requested and exit (this makes network calls)
8992
if args.list_models:
9093
models = model_manager.get_available_models(provider_clients, provider)
9194
print(f"Available models for {provider.value}:")
92-
for model in models:
93-
print(f" - {model}")
95+
for m in models:
96+
print(f" - {m}")
9497
sys.exit(0)
9598

96-
# Determine appropriate model
97-
model = get_appropriate_model(provider, provider_clients, model_manager, args.model)
99+
# Validate model (this makes network calls)
100+
final_model = get_appropriate_model(provider, provider_clients, model_manager, model)
98101

99-
return provider_clients, provider, model
102+
return provider_clients, provider, final_model
100103

101104

102105
def get_appropriate_model(
103-
provider: ModelProvider,
104-
provider_clients: ProviderClients,
105-
model_manager: ModelManager,
106+
provider: Any,
107+
provider_clients: Any,
108+
model_manager: Any,
106109
requested_model: Optional[str]
107110
) -> str:
108111
"""
109112
Get the appropriate model for the provider.
110-
111-
Args:
112-
provider (ModelProvider): The selected provider
113-
provider_clients (ProviderClients): The initialized provider clients
114-
model_manager (ModelManager): The model manager instance
115-
requested_model (Optional[str]): Model requested by the user
116-
117-
Returns:
118-
str: The appropriate model ID
119113
"""
120114
# If a specific model was requested, validate it
121115
if requested_model:
@@ -143,7 +137,7 @@ def get_appropriate_model(
143137
@dataclass
144138
class TranslationTask:
145139
"""Parameters for translation processing."""
146-
config: TranslationConfig
140+
config: Any
147141
folder: str
148142
languages: List[str]
149143
detail_languages: Dict[str, str]
@@ -154,9 +148,6 @@ class TranslationTask:
154148
def process_translations(task: TranslationTask):
155149
"""
156150
Process translations for the given task parameters.
157-
158-
Args:
159-
task: TranslationTask containing all processing parameters
160151
"""
161152
# Initialize translation service
162153
translation_service = TranslationService(task.config, task.batch_size)
@@ -192,12 +183,9 @@ def main():
192183
setup_logging(verbose=args.verbose, quiet=args.quiet)
193184

194185
try:
195-
# Initialize provider
196-
provider_clients, provider, model = initialize_provider(args)
197-
198-
# Get languages - either from args or auto-detect from PO files
186+
# 1. Get languages (Pure logic)
199187
try:
200-
respect_gitignore = not args.no_gitignore # Invert the flag
188+
respect_gitignore = not args.no_gitignore
201189
languages = LanguageDetector.validate_or_detect_languages(
202190
folder=args.folder,
203191
lang_arg=args.lang,
@@ -208,7 +196,63 @@ def main():
208196
logging.error(str(e))
209197
sys.exit(1)
210198

211-
# Create mapping between language codes and detailed names
199+
# 2. Extract model name for offline estimation (Purely offline)
200+
# Defaults to gpt-4o-mini if not specified. Avoids ModelManager to prevent early side-effects.
201+
estimated_model = args.model or "gpt-4o-mini"
202+
203+
# 3. Estimate cost if requested (Strictly Offline Terminal Flow)
204+
if args.estimate_cost:
205+
estimation = CostEstimator.estimate_cost(
206+
args.folder,
207+
languages,
208+
estimated_model,
209+
fix_fuzzy=args.fix_fuzzy,
210+
respect_gitignore=respect_gitignore
211+
)
212+
213+
print(f"\n{'=' * 40}")
214+
print(" OFFLINE TOKEN ESTIMATION REPORT")
215+
print(f"{'=' * 40}")
216+
print(f"Model: {estimation['model']}")
217+
print(f"Rate: {estimation['rate_info']}")
218+
print(f"Unique msgids: {estimation['unique_texts']:,}")
219+
print(f"Total Tokens: {estimation['total_tokens']:,} (estimated expansion included)")
220+
221+
if estimation['estimated_cost'] is not None:
222+
print(f"Estimated Cost: ${estimation['estimated_cost']:.4f}")
223+
224+
print("\nPer-language Breakdown:")
225+
for lang, data in estimation['breakdown'].items():
226+
cost_str = f"${data['cost']:.4f}" if data['cost'] is not None else "unavailable"
227+
print(f" - {lang:5}: {data['tokens']:8,} tokens | {cost_str}")
228+
229+
print("\nNote: Cost estimates are approximate and may not reflect current provider pricing.")
230+
print(f"{'=' * 40}\n")
231+
232+
if estimation['total_tokens'] == 0:
233+
logging.info("No entries require translation.")
234+
return
235+
236+
if not args.yes:
237+
confirm = input("Run actual translation with these settings? (y/n): ").lower()
238+
if confirm != 'y':
239+
logging.info("Cancelled by user.")
240+
return
241+
242+
# Issue #57: Hard exit after estimation to ensure zero side effects.
243+
# Estimation is a terminal dry-run. This prevents "Registered provider" logs
244+
# or connection attempts from leaking into the audit output.
245+
print(
246+
"\n[Audit Successful] To proceed with actual translation, "
247+
"run the command again WITHOUT --estimate-cost."
248+
)
249+
return
250+
251+
# 4. Initialize providers (Online Execution Path Starts Here)
252+
provider_clients, provider, final_model_id = get_offline_provider_info(args)
253+
provider_clients, provider, model = initialize_provider(args, provider_clients, provider, final_model_id)
254+
255+
# 5. Create mapping between language codes and detailed names
212256
try:
213257
detail_languages = create_language_mapping(languages, args.detail_lang)
214258
except ValueError as e:

python_gpt_po/services/po_file_handler.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,20 @@ def get_file_language(po_file_path, po_file, languages, folder_language):
125125

126126
if folder_language:
127127
for part in po_file_path.split(os.sep):
128-
# Try variants of the folder part
129-
variant_match = POFileHandler._try_language_variants(part, languages)
128+
# Clean part (strip .po if it's the filename)
129+
clean_part = part
130+
if part.endswith('.po'):
131+
clean_part = part[:-3]
132+
133+
# Try variants of the folder/file part
134+
variant_match = POFileHandler._try_language_variants(clean_part, languages)
130135
if variant_match:
131136
logging.info("Inferred language for .po file: %s as %s", po_file_path, variant_match)
132137
return variant_match
133138

134139
# Try base language fallback
135-
if not POFileHandler._should_skip_fallback(part):
136-
norm_part = POFileHandler.normalize_language_code(part)
140+
if not POFileHandler._should_skip_fallback(clean_part):
141+
norm_part = POFileHandler.normalize_language_code(clean_part)
137142
if norm_part and norm_part in languages:
138143
logging.info("Inferred language for .po file: %s as %s (base of %s)",
139144
po_file_path, norm_part, part)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import os
2+
import shutil
3+
import unittest
4+
5+
import polib
6+
7+
from python_gpt_po.utils.cost_estimator import CostEstimator
8+
9+
10+
class TestCostEstimatorMinimal(unittest.TestCase):
11+
def setUp(self):
12+
self.test_dir = os.path.abspath("test_cost_est_minimal")
13+
if os.path.exists(self.test_dir):
14+
shutil.rmtree(self.test_dir)
15+
os.makedirs(self.test_dir)
16+
17+
def tearDown(self):
18+
if os.path.exists(self.test_dir):
19+
shutil.rmtree(self.test_dir)
20+
21+
def test_minimal_token_math(self):
22+
"""Verify texts are tokenized once and the total is multiplied by the number of languages."""
23+
po_path = os.path.join(self.test_dir, "test.po")
24+
po = polib.POFile()
25+
# "Hello" is approx 1-2 tokens.
26+
po.append(polib.POEntry(msgid="Hello", msgstr=""))
27+
po.save(po_path)
28+
29+
# 1 language
30+
est1 = CostEstimator.estimate_cost(self.test_dir, ["fr"], "gpt-4o-mini")
31+
t1 = est1['total_tokens']
32+
33+
# 3 languages
34+
est3 = CostEstimator.estimate_cost(self.test_dir, ["fr", "es", "de"], "gpt-4o-mini")
35+
t3 = est3['total_tokens']
36+
37+
self.assertEqual(t3, t1 * 3)
38+
39+
def test_pricing_lookup(self):
40+
"""Verify dynamic pricing lookup via genai-prices."""
41+
po_path = os.path.join(self.test_dir, "test.po")
42+
po = polib.POFile()
43+
po.append(polib.POEntry(msgid="Test", msgstr=""))
44+
po.save(po_path)
45+
46+
# Known model
47+
est_known = CostEstimator.estimate_cost(self.test_dir, ["fr"], "gpt-4o-mini")
48+
self.assertIsNotNone(est_known['estimated_cost'])
49+
50+
# Unknown model
51+
est_unknown = CostEstimator.estimate_cost(self.test_dir, ["fr"], "unknown-model")
52+
self.assertIsNone(est_unknown['estimated_cost'])
53+
54+
def test_zero_work(self):
55+
"""Verify zero tokens when everything is translated."""
56+
po_path = os.path.join(self.test_dir, "test.po")
57+
po = polib.POFile()
58+
po.append(polib.POEntry(msgid="Hello", msgstr="Bonjour"))
59+
po.save(po_path)
60+
61+
est = CostEstimator.estimate_cost(self.test_dir, ["fr"], "gpt-4o-mini")
62+
self.assertEqual(est['total_tokens'], 0)
63+
64+
65+
if __name__ == '__main__':
66+
unittest.main()

python_gpt_po/utils/cli.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,16 @@ def parse_args() -> Namespace:
186186
metavar="SIZE",
187187
help="Number of strings to translate in each batch (default: 50)"
188188
)
189+
advanced_group.add_argument(
190+
"--estimate-cost",
191+
action="store_true",
192+
help="Estimate token usage and cost before translating"
193+
)
194+
advanced_group.add_argument(
195+
"-y", "--yes",
196+
action="store_true",
197+
help="Skip confirmation prompt when using --estimate-cost"
198+
)
189199
fuzzy_group.add_argument(
190200
"--fuzzy",
191201
action="store_true",

0 commit comments

Comments
 (0)