| | """ |
| | Extraction Validation |
| | |
| | Validates extracted data and provides confidence scoring. |
| | """ |
| |
|
| | import logging |
| | from dataclasses import dataclass, field |
| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| | from ..chunks.models import ( |
| | ExtractionResult, |
| | FieldExtraction, |
| | ConfidenceLevel, |
| | ) |
| | from .schema import ExtractionSchema, FieldSpec, FieldType |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
@dataclass
class ValidationIssue:
    """A validation issue found during extraction validation."""

    field_name: str
    issue_type: str
    message: str
    severity: str = "warning"
    suggested_action: Optional[str] = None


@dataclass
class ValidationResult:
    """Result of extraction validation."""

    is_valid: bool
    issues: List[ValidationIssue] = field(default_factory=list)
    confidence_score: float = 0.0
    field_scores: Dict[str, float] = field(default_factory=dict)
    recommendations: List[str] = field(default_factory=list)

    @property
    def error_count(self) -> int:
        return sum(1 for i in self.issues if i.severity == "error")

    @property
    def warning_count(self) -> int:
        return sum(1 for i in self.issues if i.severity == "warning")

    def get_issues_for_field(self, field_name: str) -> List[ValidationIssue]:
        """Get all issues for a specific field."""
        return [i for i in self.issues if i.field_name == field_name]


class ExtractionValidator:
    """
    Validates extraction results against schemas.

    Checks for:
    - Required field presence
    - Type correctness
    - Value constraints
    - Confidence thresholds
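
    Example (illustrative sketch; building the `extraction` and `schema`
    objects depends on ExtractionResult / ExtractionSchema, which are
    defined elsewhere and only assumed here):

        validator = ExtractionValidator(min_confidence=0.6, strict_mode=False)
        result = validator.validate(extraction, schema)
        if not result.is_valid:
            for issue in result.issues:
                logger.warning("%s: %s", issue.field_name, issue.message)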
| | """ |
| |
|
| | def __init__( |
| | self, |
| | min_confidence: float = 0.5, |
| | strict_mode: bool = False, |
| | ): |
| | self.min_confidence = min_confidence |
| | self.strict_mode = strict_mode |
| |
|
    def validate(
        self,
        extraction: ExtractionResult,
        schema: ExtractionSchema,
    ) -> ValidationResult:
        """
        Validate an extraction result against a schema.

        Args:
            extraction: Extraction result to validate
            schema: Schema defining expected fields

        Returns:
            ValidationResult with issues and scores
        """
        issues: List[ValidationIssue] = []
        field_scores: Dict[str, float] = {}

        # Validate each field defined by the schema.
        for field_spec in schema.fields:
            field_issues, score = self._validate_field(
                field_spec=field_spec,
                extraction=extraction,
            )
            issues.extend(field_issues)
            field_scores[field_spec.name] = score

        # Flag extracted fields that the schema does not define.
        expected_fields = {f.name for f in schema.fields}
        for field_name in extraction.data.keys():
            if field_name not in expected_fields:
                issues.append(ValidationIssue(
                    field_name=field_name,
                    issue_type="unexpected",
                    message=f"Unexpected field: {field_name}",
                    severity="info",
                ))

        # Overall confidence is the mean of the per-field scores.
        if field_scores:
            confidence_score = sum(field_scores.values()) / len(field_scores)
        else:
            confidence_score = 0.0

        # Valid only if there are no errors and confidence clears the schema threshold.
        is_valid = (
            all(i.severity != "error" for i in issues) and
            confidence_score >= schema.min_overall_confidence
        )

        recommendations = self._generate_recommendations(issues, extraction)

        return ValidationResult(
            is_valid=is_valid,
            issues=issues,
            confidence_score=confidence_score,
            field_scores=field_scores,
            recommendations=recommendations,
        )

    def _validate_field(
        self,
        field_spec: FieldSpec,
        extraction: ExtractionResult,
    ) -> Tuple[List[ValidationIssue], float]:
        """Validate a single field, returning its issues and confidence score."""
        issues: List[ValidationIssue] = []
        score = 1.0

        value = extraction.data.get(field_spec.name)
        field_extraction = self._get_field_extraction(field_spec.name, extraction)

        # Missing value: an error for required fields, a clean pass for optional ones.
        if value is None:
            if field_spec.required:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="missing",
                    message=f"Required field '{field_spec.name}' is missing",
                    severity="error",
                    suggested_action="Manual review required",
                ))
                return issues, 0.0
            return issues, 1.0

        # Penalize fields the extractor abstained on.
        if field_spec.name in extraction.abstained_fields:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="abstained",
                message=f"Field '{field_spec.name}' was abstained due to low confidence",
                severity="warning",
                suggested_action="Manual verification recommended",
            ))
            score *= 0.5

        # Weight the score by extraction confidence and flag values below the threshold.
        if field_extraction:
            if field_extraction.confidence < self.min_confidence:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="low_confidence",
                    message=(
                        f"Field '{field_spec.name}' has low confidence: "
                        f"{field_extraction.confidence:.2f}"
                    ),
                    severity="warning",
                    suggested_action="Manual verification recommended",
                ))
            score *= field_extraction.confidence

        # Type check.
        type_issues = self._validate_type(field_spec, value)
        issues.extend(type_issues)
        if type_issues:
            score *= 0.7

        # Constraint checks (pattern, range, length, allowed values).
        constraint_issues = self._validate_constraints(field_spec, value)
        issues.extend(constraint_issues)
        if constraint_issues:
            score *= 0.8

        return issues, max(0.0, min(1.0, score))

    def _validate_type(
        self,
        field_spec: FieldSpec,
        value: Any,
    ) -> List[ValidationIssue]:
        """Validate the field's type."""
        issues = []

        expected_type = self._get_expected_python_type(field_spec.field_type)

        if expected_type and not isinstance(value, expected_type):
            # Only report a mismatch if the value cannot be coerced to the expected type.
            try:
                expected_type(value)
            except (ValueError, TypeError):
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="type_mismatch",
                    message=(
                        f"Field '{field_spec.name}' expected {field_spec.field_type.value}, "
                        f"got {type(value).__name__}"
                    ),
                    severity="error" if self.strict_mode else "warning",
                ))

        return issues

    def _validate_constraints(
        self,
        field_spec: FieldSpec,
        value: Any,
    ) -> List[ValidationIssue]:
        """Validate field constraints."""
        issues = []

        # Regex pattern constraint.
        if field_spec.pattern:
            if not re.match(field_spec.pattern, str(value)):
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="pattern_mismatch",
                    message=f"Field '{field_spec.name}' does not match pattern: {field_spec.pattern}",
                    severity="warning",
                ))

        # Numeric range constraints (skipped silently for non-numeric values).
        try:
            num_value = float(value)
            if field_spec.min_value is not None and num_value < field_spec.min_value:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="below_minimum",
                    message=f"Field '{field_spec.name}' value {num_value} is below minimum {field_spec.min_value}",
                    severity="warning",
                ))
            if field_spec.max_value is not None and num_value > field_spec.max_value:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="above_maximum",
                    message=f"Field '{field_spec.name}' value {num_value} is above maximum {field_spec.max_value}",
                    severity="warning",
                ))
        except (ValueError, TypeError):
            pass

        # Length constraints on the string representation.
        str_value = str(value)
        if field_spec.min_length is not None and len(str_value) < field_spec.min_length:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="too_short",
                message=f"Field '{field_spec.name}' is too short: {len(str_value)} < {field_spec.min_length}",
                severity="warning",
            ))
        if field_spec.max_length is not None and len(str_value) > field_spec.max_length:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="too_long",
                message=f"Field '{field_spec.name}' is too long: {len(str_value)} > {field_spec.max_length}",
                severity="warning",
            ))

        # Enumeration constraint.
        if field_spec.allowed_values and value not in field_spec.allowed_values:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="not_in_allowed",
                message=f"Field '{field_spec.name}' value '{value}' not in allowed values",
                severity="warning",
            ))

        return issues

    def _get_field_extraction(
        self,
        field_name: str,
        extraction: ExtractionResult,
    ) -> Optional[FieldExtraction]:
        """Get a field extraction by name."""
        for fe in extraction.fields:
            if fe.field_name == field_name:
                return fe
        return None

    def _get_expected_python_type(self, field_type: FieldType) -> Optional[type]:
        """Get the expected Python type for a field type (None means no type check)."""
        type_map = {
            FieldType.INTEGER: int,
            FieldType.FLOAT: float,
            FieldType.BOOLEAN: bool,
            FieldType.LIST: list,
            FieldType.OBJECT: dict,
        }
        return type_map.get(field_type)

    def _generate_recommendations(
        self,
        issues: List[ValidationIssue],
        extraction: ExtractionResult,
    ) -> List[str]:
        """Generate recommendations based on issues."""
        recommendations = []

        # Tally the issue types that warrant follow-up actions.
        missing_count = sum(1 for i in issues if i.issue_type == "missing")
        low_conf_count = sum(1 for i in issues if i.issue_type == "low_confidence")
        type_count = sum(1 for i in issues if i.issue_type == "type_mismatch")

        if missing_count > 0:
            recommendations.append(
                f"Review document for {missing_count} missing required field(s)"
            )

        if low_conf_count > 0:
            recommendations.append(
                f"Manual verification recommended for {low_conf_count} low-confidence field(s)"
            )

        if type_count > 0:
            recommendations.append(
                f"Check data types for {type_count} field(s) with type mismatches"
            )

        if extraction.overall_confidence < 0.5:
            recommendations.append(
                "Overall extraction confidence is low - consider manual review"
            )

        if len(extraction.abstained_fields) > 0:
            recommendations.append(
                f"System abstained on {len(extraction.abstained_fields)} field(s) due to uncertainty"
            )

        return recommendations


class CrossFieldValidator:
    """
    Validates relationships between fields.

    Checks for:
    - Consistency (e.g., subtotal + tax = total)
    - Logical relationships
    - Date ordering
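
    Example (illustrative sketch; the field names in the rules are
    hypothetical and must match whatever the extraction actually contains):

        validator = CrossFieldValidator()
        issues = validator.validate_consistency(extraction, rules=[
            {"type": "sum", "fields": ["subtotal", "tax"], "equals": "total"},
            {"type": "date_order", "before": "invoice_date", "after": "due_date"},
        ])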
| | """ |
| |
|
| | def validate_consistency( |
| | self, |
| | extraction: ExtractionResult, |
| | rules: List[Dict[str, Any]], |
| | ) -> List[ValidationIssue]: |
| | """ |
| | Validate cross-field consistency rules. |
| | |
| | Rules format: |
| | { |
| | "type": "sum", |
| | "fields": ["subtotal", "tax"], |
| | "equals": "total", |
| | "tolerance": 0.01 |
| | } |
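
        Other supported rule types (shapes mirror the handlers below; the
        field names are only examples):
            {
                "type": "date_order",
                "before": "start_date",
                "after": "end_date"
            }
            {
                "type": "required_if",
                "field": "tax_amount",
                "required_if": "tax_rate"
            }

        For "required_if", an optional "value" key restricts the trigger to
        that exact value; without it, any non-null value of the condition
        field triggers the check.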
| | """ |
| | issues = [] |
| |
|
| | for rule in rules: |
| | rule_type = rule.get("type") |
| |
|
| | if rule_type == "sum": |
| | issue = self._validate_sum_rule(extraction, rule) |
| | if issue: |
| | issues.append(issue) |
| |
|
| | elif rule_type == "date_order": |
| | issue = self._validate_date_order(extraction, rule) |
| | if issue: |
| | issues.append(issue) |
| |
|
| | elif rule_type == "required_if": |
| | issue = self._validate_required_if(extraction, rule) |
| | if issue: |
| | issues.append(issue) |
| |
|
| | return issues |
| |
|
    def _validate_sum_rule(
        self,
        extraction: ExtractionResult,
        rule: Dict[str, Any],
    ) -> Optional[ValidationIssue]:
        """Validate that the sum of fields equals another field."""
        fields = rule.get("fields", [])
        equals_field = rule.get("equals")
        tolerance = rule.get("tolerance", 0.01)

        try:
            # Missing or null fields contribute zero to the sum.
            sum_value = sum(
                float(extraction.data.get(f, 0) or 0)
                for f in fields
            )
            expected = float(extraction.data.get(equals_field, 0) or 0)

            if abs(sum_value - expected) > tolerance:
                return ValidationIssue(
                    field_name=equals_field,
                    issue_type="sum_mismatch",
                    message=f"Sum of {fields} ({sum_value}) does not equal {equals_field} ({expected})",
                    severity="warning",
                )
        except (ValueError, TypeError):
            pass

        return None

    def _validate_date_order(
        self,
        extraction: ExtractionResult,
        rule: Dict[str, Any],
    ) -> Optional[ValidationIssue]:
        """Validate that dates are in the expected order."""
        before_field = rule.get("before")
        after_field = rule.get("after")

        before_val = extraction.data.get(before_field)
        after_val = extraction.data.get(after_field)

        if not before_val or not after_val:
            return None

        # Try a handful of common date formats; unparseable dates are ignored.
        formats = ["%Y-%m-%d", "%m/%d/%Y", "%d/%m/%Y", "%B %d, %Y"]

        def parse_date(value: Any) -> Optional[datetime]:
            for fmt in formats:
                try:
                    return datetime.strptime(str(value), fmt)
                except ValueError:
                    continue
            return None

        try:
            before_date = parse_date(before_val)
            after_date = parse_date(after_val)

            if before_date and after_date and before_date > after_date:
                return ValidationIssue(
                    field_name=after_field,
                    issue_type="date_order",
                    message=f"Date {before_field} ({before_val}) should be before {after_field} ({after_val})",
                    severity="warning",
                )
        except Exception:
            pass

        return None

    def _validate_required_if(
        self,
        extraction: ExtractionResult,
        rule: Dict[str, Any],
    ) -> Optional[ValidationIssue]:
        """Validate conditionally required fields."""
        field_name = rule.get("field")
        required_if = rule.get("required_if")
        condition_value = rule.get("value")

        condition_field_value = extraction.data.get(required_if)

        # With an explicit value the condition is an equality check;
        # otherwise any non-null value of the condition field triggers the rule.
        if condition_value is not None:
            condition_met = condition_field_value == condition_value
        else:
            condition_met = condition_field_value is not None

        if condition_met:
            field_value = extraction.data.get(field_name)
            if field_value is None:
                return ValidationIssue(
                    field_name=field_name,
                    issue_type="conditional_required",
                    message=f"Field '{field_name}' is required when '{required_if}' is present",
                    severity="warning",
                )

        return None