Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions evaluations/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,32 +72,49 @@ pub struct Args {
/// halves of the CI in the case of asymmetric CIs) <= precision_target.
#[arg(long = "adaptive-stopping-precision", value_parser = parse_precision_target, value_delimiter = ',', num_args = 0..)]
pub precision_targets: Vec<(String, f32)>,

/// Per-evaluator cutoff thresholds for pass/fail exit status.
/// Format: evaluator_name=cutoff, comma-separated for multiple evaluators.
/// Example: --cutoffs exact_match=0.95,llm_judge=0.8
/// If both this CLI flag and evaluator config `cutoff` are provided
/// for the same evaluator, the CLI value takes precedence.
#[arg(long, value_parser = parse_cutoff_target, value_delimiter = ',', num_args = 0..)]
pub cutoffs: Vec<(String, f32)>,
}

/// Parse a single precision target in format "evaluator_name=precision_target"
fn parse_precision_target(s: &str) -> Result<(String, f32), String> {
let s = s.trim();
if s.is_empty() {
return Err("Precision target cannot be empty".to_string());
parse_named_non_negative_float(s, "precision")
}

/// Parse a single cutoff target in format "evaluator_name=cutoff"
fn parse_cutoff_target(s: &str) -> Result<(String, f32), String> {
parse_named_non_negative_float(s, "cutoff")
}

fn parse_named_non_negative_float(input: &str, value_label: &str) -> Result<(String, f32), String> {
let input = input.trim();
if input.is_empty() {
return Err(format!("{value_label} cannot be empty"));
}

let parts: Vec<&str> = s.splitn(2, '=').collect();
let parts: Vec<&str> = input.splitn(2, '=').collect();
if parts.len() != 2 {
return Err(format!(
"Invalid precision format: `{s}`. Expected format: evaluator_name=precision_target"
"Invalid {value_label} format: `{input}`. Expected format: evaluator_name=value"
));
}

let evaluator_name = parts[0].to_string();
let precision_target = parts[1]
let value = parts[1]
.parse::<f32>()
.map_err(|e| format!("Invalid precision value `{}`: {e}", parts[1]))?;
.map_err(|e| format!("Invalid `{value_label}` value `{}`: {e}", parts[1]))?;

if precision_target < 0.0 {
if value < 0.0 {
return Err(format!(
"Precision value must be non-negative, got {precision_target}"
"{value_label} value must be non-negative, got {value}",
));
}

Ok((evaluator_name, precision_target))
Ok((evaluator_name, value))
}
8 changes: 8 additions & 0 deletions evaluations/src/evaluators/llm_judge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ mod tests {
let llm_judge_config = LLMJudgeConfig {
input_format: LLMJudgeInputFormat::Serialized,
output_type: LLMJudgeOutputType::Float,
#[expect(deprecated)]
cutoff: None,
optimize: LLMJudgeOptimize::Max,
include: LLMJudgeIncludeConfig::default(),
Expand Down Expand Up @@ -644,6 +645,7 @@ mod tests {
let llm_judge_config = LLMJudgeConfig {
input_format: LLMJudgeInputFormat::Serialized,
output_type: LLMJudgeOutputType::Float,
#[expect(deprecated)]
cutoff: None,
optimize: LLMJudgeOptimize::Max,
include: LLMJudgeIncludeConfig {
Expand Down Expand Up @@ -858,6 +860,7 @@ mod tests {
let config = LLMJudgeConfig {
input_format: LLMJudgeInputFormat::Serialized,
output_type: LLMJudgeOutputType::Float,
#[expect(deprecated)]
cutoff: None,
optimize: LLMJudgeOptimize::Max,
include: LLMJudgeIncludeConfig {
Expand Down Expand Up @@ -892,6 +895,7 @@ mod tests {
let config = LLMJudgeConfig {
input_format: LLMJudgeInputFormat::Serialized,
output_type: LLMJudgeOutputType::Float,
#[expect(deprecated)]
cutoff: None,
optimize: LLMJudgeOptimize::Max,
include: LLMJudgeIncludeConfig {
Expand Down Expand Up @@ -989,6 +993,7 @@ mod tests {
let config = LLMJudgeConfig {
input_format: LLMJudgeInputFormat::Messages,
output_type: LLMJudgeOutputType::Float,
#[expect(deprecated)]
cutoff: None,
optimize: LLMJudgeOptimize::Max,
include: LLMJudgeIncludeConfig {
Expand All @@ -1012,6 +1017,7 @@ mod tests {
let config = LLMJudgeConfig {
input_format: LLMJudgeInputFormat::Messages,
output_type: LLMJudgeOutputType::Float,
#[expect(deprecated)]
cutoff: None,
optimize: LLMJudgeOptimize::Max,
include: LLMJudgeIncludeConfig {
Expand All @@ -1033,6 +1039,7 @@ mod tests {
let llm_judge_config = LLMJudgeConfig {
input_format: LLMJudgeInputFormat::Messages,
output_type: LLMJudgeOutputType::Float,
#[expect(deprecated)]
cutoff: None,
optimize: LLMJudgeOptimize::Max,
include: LLMJudgeIncludeConfig {
Expand Down Expand Up @@ -1147,6 +1154,7 @@ mod tests {
let llm_judge_config = LLMJudgeConfig {
input_format: LLMJudgeInputFormat::Serialized,
output_type: LLMJudgeOutputType::Float,
#[expect(deprecated)]
cutoff: None,
optimize: LLMJudgeOptimize::Max,
include: LLMJudgeIncludeConfig::default(),
Expand Down
184 changes: 164 additions & 20 deletions evaluations/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ pub async fn run_evaluation(

// Convert Vec<(String, f32)> to HashMap<String, f32> for precision_targets
let precision_targets: HashMap<String, f32> = args.precision_targets.into_iter().collect();
let cli_cutoffs: HashMap<String, f32> = args.cutoffs.into_iter().collect();

let output_format = args.format.clone();
let result =
Expand Down Expand Up @@ -307,31 +308,37 @@ pub async fn run_evaluation(
progress_bar.finish_with_message("Done");
}

if evaluation_stats.output_format == OutputFormat::Pretty {
let EvaluationConfig::Inference(inference_evaluation_config) = &*result.evaluation_config;
let stats = evaluation_stats.compute_stats(&inference_evaluation_config.evaluators);
let EvaluationConfig::Inference(inference_evaluation_config) = &*result.evaluation_config;
let stats = evaluation_stats.compute_stats(&inference_evaluation_config.evaluators);

if evaluation_stats.output_format == OutputFormat::Pretty {
// Print all stats
for (evaluator_name, evaluator_stats) in &stats {
writeln!(writer, "{evaluator_name}: {evaluator_stats}")?;
}
}

// Check cutoffs and handle failures
let failures = check_evaluator_cutoffs(&stats, &inference_evaluation_config.evaluators)?;
// Check cutoffs and handle failures
let effective_cutoffs =
resolve_effective_cutoffs(&inference_evaluation_config.evaluators, &cli_cutoffs)?;
let failures = check_evaluator_cutoffs(
&stats,
&inference_evaluation_config.evaluators,
&effective_cutoffs,
)?;

// Print failure messages
if !failures.is_empty() {
for (name, cutoff, actual) in &failures {
writeln!(
writer,
"Failed cutoff for evaluator {name} ({cutoff:.2}, got {actual:.2})"
)?;
tracing::warn!(
evaluator = %name,
cutoff = cutoff,
actual = actual,
"Failed cutoff for evaluator `{name}` ({cutoff:.2}, got {actual:.2})"
);
}

// If there are failures, return an error with all failures listed
if !failures.is_empty() {
let failure_messages = format_cutoff_failures(&failures);
bail!("Failed cutoffs for evaluators: {failure_messages}");
}
let failure_messages = format_cutoff_failures(&failures);
bail!("Failed cutoffs for evaluators: {failure_messages}");
}

// Since we construct our own `ClickHouseConnectionInfo` outside of our `TensorZeroClient`,
Expand Down Expand Up @@ -741,12 +748,70 @@ pub async fn run_evaluation_core_streaming(
})
}

/// Resolves cutoff thresholds for pass/fail checks during migration.
///
/// Sources:
/// - legacy config cutoff (`EvaluatorConfig::cutoff()`)
/// - CLI cutoff (`--cutoffs`)
///
/// Precedence:
/// - If both are present for the same evaluator, CLI value wins and a warning is emitted.
/// - If only one source is present, that value is used.
///
/// # Errors
/// Fails if `--cutoffs` names an evaluator that does not exist in the config.
///
/// TODO(#6603): Delete this after fully removing cutoffs in evaluator configs.
pub fn resolve_effective_cutoffs(
    evaluator_configs: &HashMap<String, EvaluatorConfig>,
    cli_cutoffs: &HashMap<String, f32>,
) -> Result<HashMap<String, f32>> {
    // Reject CLI cutoffs that reference evaluators absent from the config.
    if let Some(unknown) = cli_cutoffs
        .keys()
        .find(|name| !evaluator_configs.contains_key(*name))
    {
        return Err(anyhow!("Unknown evaluator in --cutoffs: `{unknown}`"));
    }

    let mut effective_cutoffs = HashMap::new();

    for (evaluator_name, evaluator_config) in evaluator_configs {
        let config_cutoff = evaluator_config.cutoff();
        let cli_cutoff = cli_cutoffs.get(evaluator_name).copied();

        // CLI beats config; a lone source is used as-is; neither means no cutoff.
        let resolved = match (config_cutoff, cli_cutoff) {
            (Some(config_value), Some(cli_value)) => {
                tracing::warn!(
                    evaluator_name = %evaluator_name,
                    config_cutoff = config_value,
                    cli_cutoff = cli_value,
                    "Evaluator config `cutoff` is deprecated; please remove `cutoff` from config. Using CLI `--cutoffs` value."
                );
                Some(cli_value)
            }
            (Some(config_value), None) => {
                tracing::warn!(
                    evaluator_name = %evaluator_name,
                    config_cutoff = config_value,
                    "Evaluator config `cutoff` is deprecated; please remove `cutoff` from config and pass `--cutoffs` on the CLI instead."
                );
                Some(config_value)
            }
            (None, cli_only) => cli_only,
        };

        if let Some(value) = resolved {
            effective_cutoffs.insert(evaluator_name.clone(), value);
        }
    }

    Ok(effective_cutoffs)
}

/// Checks if evaluator results meet their cutoff thresholds
///
/// Returns a vector of failures with (evaluator_name, cutoff, actual_value)
pub fn check_evaluator_cutoffs(
stats: &HashMap<String, stats::EvaluatorStats>,
evaluator_configs: &HashMap<String, EvaluatorConfig>,
effective_cutoffs: &HashMap<String, f32>,
) -> Result<Vec<(String, f32, f32)>> {
let mut failures = Vec::new();

Expand All @@ -755,7 +820,7 @@ pub fn check_evaluator_cutoffs(
.get(evaluator_name)
.ok_or_else(|| anyhow!("Evaluator not found for computing stats"))?;

if let Some(cutoff) = evaluator_config.cutoff() {
if let Some(cutoff) = effective_cutoffs.get(evaluator_name).copied() {
match evaluator_config.optimize() {
MetricConfigOptimize::Max => {
if evaluator_stats.mean < cutoff {
Expand Down Expand Up @@ -1250,27 +1315,106 @@ mod tests {
let mut evaluators = HashMap::new();
evaluators.insert(
"evaluator1".to_string(),
#[expect(deprecated)]
EvaluatorConfig::ExactMatch(ExactMatchConfig { cutoff: Some(0.5) }),
);
evaluators.insert(
"evaluator2".to_string(),
#[expect(deprecated)]
EvaluatorConfig::ExactMatch(ExactMatchConfig { cutoff: Some(0.6) }),
);
evaluators.insert(
"evaluator3".to_string(),
#[expect(deprecated)]
EvaluatorConfig::ExactMatch(ExactMatchConfig { cutoff: None }),
);
evaluators
};
let failures = check_evaluator_cutoffs(&stats, &evaluators).unwrap();

let cli_cutoffs = HashMap::new();
let effective_cutoffs = resolve_effective_cutoffs(&evaluators, &cli_cutoffs)
.expect("cutoff resolution should succeed");

let failures = check_evaluator_cutoffs(&stats, &evaluators, &effective_cutoffs)
.expect("cutoff checks should succeed");
assert_eq!(failures.len(), 2);

// Check that both expected failures are present, regardless of order
assert!(failures.contains(&("evaluator1".to_string(), 0.5, 0.4)));
assert!(failures.contains(&("evaluator2".to_string(), 0.6, 0.3)));
assert!(
failures.contains(&("evaluator1".to_string(), 0.5, 0.4)),
"should include evaluator1 failure when mean is below cutoff"
);
assert!(
failures.contains(&("evaluator2".to_string(), 0.6, 0.3)),
"should include evaluator2 failure when mean is below cutoff"
);

// Check that evaluator3 is not in the failures list since it has no cutoff
assert!(!failures.iter().any(|(name, _, _)| name == "evaluator3"));
assert!(
!failures.iter().any(|(name, _, _)| name == "evaluator3"),
"should not include evaluator3 when no cutoff is configured"
);
}

#[test]
fn test_resolve_effective_cutoffs_cli_takes_precedence() {
    let mut evaluator_configs = HashMap::new();
    evaluator_configs.insert(
        "evaluator1".to_string(),
        #[expect(deprecated)]
        EvaluatorConfig::ExactMatch(ExactMatchConfig { cutoff: Some(0.5) }),
    );
    evaluator_configs.insert(
        "evaluator2".to_string(),
        #[expect(deprecated)]
        EvaluatorConfig::ExactMatch(ExactMatchConfig { cutoff: Some(0.6) }),
    );

    // Any unknown evaluator name in the CLI map is a hard error, even if
    // other entries are valid.
    let cutoffs_with_unknown = HashMap::from([
        ("evaluator1".to_string(), 0.8),
        ("evaluator3".to_string(), 0.3),
    ]);
    let err = resolve_effective_cutoffs(&evaluator_configs, &cutoffs_with_unknown)
        .expect_err("unknown evaluator in CLI cutoffs should error");
    assert!(
        err.to_string().contains("`evaluator3`"),
        "should error on unknown evaluator cutoff names"
    );

    // With only known names: the CLI value overrides the config value, and
    // config values fill in where the CLI is silent.
    let known_cutoffs = HashMap::from([("evaluator1".to_string(), 0.8)]);
    let effective = resolve_effective_cutoffs(&evaluator_configs, &known_cutoffs)
        .expect("cutoff resolution should succeed when evaluator names are valid");

    assert_eq!(
        effective.get("evaluator1"),
        Some(&0.8),
        "CLI cutoff should override config cutoff for the same evaluator"
    );
    assert_eq!(
        effective.get("evaluator2"),
        Some(&0.6),
        "config cutoff should still apply when no CLI cutoff is provided"
    );
}

#[test]
fn test_resolve_effective_cutoffs_unknown_evaluator() {
    // One valid evaluator; the CLI map mixes a valid name with an unknown one.
    let mut evaluator_configs = HashMap::new();
    evaluator_configs.insert(
        "evaluator1".to_string(),
        EvaluatorConfig::ExactMatch(ExactMatchConfig::default()),
    );
    let cli_cutoffs = HashMap::from([
        ("nonexistent".to_string(), 0.5),
        ("evaluator1".to_string(), 0.8),
    ]);

    let err = resolve_effective_cutoffs(&evaluator_configs, &cli_cutoffs)
        .expect_err("should error when CLI cutoff references unknown evaluator");
    assert!(
        err.to_string().contains("`nonexistent`"),
        "error should name the unknown evaluator, got: {err}"
    );
}

#[test]
Expand Down
Loading
Loading