git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
854 lines
31 KiB
Rust
854 lines
31 KiB
Rust
//! Type Checker for HoTT
//!
//! Implements bidirectional type checking with:
//! - Type inference (synthesis)
//! - Type checking
//! - Normalization (beta reduction)
//! - Context management
|
|
|
|
use std::collections::HashMap;
|
|
use super::{Type, Term, TypeError, Level, fresh_id};
|
|
|
|
/// Typing context
|
|
pub type Context = Vec<(String, Type)>;
|
|
|
|
/// Result of type checking
|
|
pub type CheckResult<T> = Result<T, TypeError>;
|
|
|
|
/// Bidirectional type checker for HoTT
|
|
#[derive(Clone)]
|
|
pub struct TypeChecker {
|
|
/// Typing context: variable -> type bindings
|
|
context: Context,
|
|
/// Universe level constraints
|
|
level_constraints: HashMap<String, Level>,
|
|
/// Normalization cache
|
|
cache: HashMap<u64, Term>,
|
|
}
|
|
|
|
impl TypeChecker {
|
|
/// Create a new type checker with empty context
|
|
pub fn new() -> Self {
|
|
TypeChecker {
|
|
context: Vec::new(),
|
|
level_constraints: HashMap::new(),
|
|
cache: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
/// Create type checker with initial context
|
|
pub fn with_context(&self, ctx: Context) -> Self {
|
|
TypeChecker {
|
|
context: ctx,
|
|
level_constraints: self.level_constraints.clone(),
|
|
cache: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
/// Extend context with a new binding
|
|
pub fn extend(&self, var: String, ty: Type) -> Self {
|
|
let mut new_ctx = self.context.clone();
|
|
new_ctx.push((var, ty));
|
|
TypeChecker {
|
|
context: new_ctx,
|
|
level_constraints: self.level_constraints.clone(),
|
|
cache: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
/// Look up variable in context
|
|
pub fn lookup(&self, var: &str) -> Option<&Type> {
|
|
self.context.iter().rev()
|
|
.find(|(v, _)| v == var)
|
|
.map(|(_, ty)| ty)
|
|
}
|
|
|
|
/// Type checking: verify term has expected type
|
|
pub fn check(&self, term: &Term, expected: &Type) -> CheckResult<()> {
|
|
match (term, expected) {
|
|
// Check lambda against Pi-type
|
|
(Term::Lambda { var, body }, Type::Pi { domain, codomain, .. }) => {
|
|
let extended = self.extend(var.clone(), (**domain).clone());
|
|
let codomain_ty = codomain(&Term::Var(var.clone()));
|
|
extended.check(body, &codomain_ty)
|
|
}
|
|
|
|
// Check lambda against arrow type
|
|
(Term::Lambda { var, body }, Type::Arrow(domain, codomain)) => {
|
|
let extended = self.extend(var.clone(), (**domain).clone());
|
|
extended.check(body, codomain)
|
|
}
|
|
|
|
// Check pair against Sigma-type
|
|
(Term::Pair { fst, snd }, Type::Sigma { base, fiber, .. }) => {
|
|
self.check(fst, base)?;
|
|
let fiber_ty = fiber(fst);
|
|
self.check(snd, &fiber_ty)
|
|
}
|
|
|
|
// Check pair against product type
|
|
(Term::Pair { fst, snd }, Type::Product(left, right)) => {
|
|
self.check(fst, left)?;
|
|
self.check(snd, right)
|
|
}
|
|
|
|
// Check reflexivity against identity type
|
|
(Term::Refl(t), Type::Id(ty, left, right)) => {
|
|
self.check(t, ty)?;
|
|
// Verify t equals both left and right
|
|
let t_norm = self.normalize(t);
|
|
let left_norm = self.normalize(left);
|
|
let right_norm = self.normalize(right);
|
|
|
|
if !t_norm.structural_eq(&left_norm) || !t_norm.structural_eq(&right_norm) {
|
|
return Err(TypeError::TypeMismatch {
|
|
expected: format!("{:?} = {:?}", left, right),
|
|
found: format!("refl({:?})", t),
|
|
});
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
// Check star against Unit
|
|
(Term::Star, Type::Unit) => Ok(()),
|
|
|
|
// Check true/false against Bool
|
|
(Term::True, Type::Bool) | (Term::False, Type::Bool) => Ok(()),
|
|
|
|
// Check zero against Nat
|
|
(Term::Zero, Type::Nat) => Ok(()),
|
|
|
|
// Check natural literal against Nat
|
|
(Term::NatLit(_), Type::Nat) => Ok(()),
|
|
|
|
// Check successor against Nat
|
|
(Term::Succ(n), Type::Nat) => self.check(n, &Type::Nat),
|
|
|
|
// Check injections against coproduct
|
|
(Term::Inl(t), Type::Coprod(left, _)) => self.check(t, left),
|
|
(Term::Inr(t), Type::Coprod(_, right)) => self.check(t, right),
|
|
|
|
// Fall back to inference and comparison
|
|
_ => {
|
|
let inferred = self.infer(term)?;
|
|
if self.types_equal(&inferred, expected) {
|
|
Ok(())
|
|
} else {
|
|
Err(TypeError::TypeMismatch {
|
|
expected: format!("{:?}", expected),
|
|
found: format!("{:?}", inferred),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Type inference: synthesize the type of a term
|
|
pub fn infer(&self, term: &Term) -> CheckResult<Type> {
|
|
match term {
|
|
// Variable lookup
|
|
Term::Var(name) => {
|
|
self.lookup(name)
|
|
.cloned()
|
|
.ok_or_else(|| TypeError::UnboundVariable(name.clone()))
|
|
}
|
|
|
|
// Star has type Unit
|
|
Term::Star => Ok(Type::Unit),
|
|
|
|
// Booleans
|
|
Term::True | Term::False => Ok(Type::Bool),
|
|
|
|
// Naturals
|
|
Term::Zero | Term::NatLit(_) => Ok(Type::Nat),
|
|
Term::Succ(n) => {
|
|
self.check(n, &Type::Nat)?;
|
|
Ok(Type::Nat)
|
|
}
|
|
|
|
// Application
|
|
Term::App { func, arg } => {
|
|
let func_ty = self.infer(func)?;
|
|
match func_ty {
|
|
Type::Pi { domain, codomain, .. } => {
|
|
self.check(arg, &domain)?;
|
|
Ok(codomain(arg))
|
|
}
|
|
Type::Arrow(domain, codomain) => {
|
|
self.check(arg, &domain)?;
|
|
Ok(*codomain)
|
|
}
|
|
_ => Err(TypeError::NotAFunction(format!("{:?}", func_ty))),
|
|
}
|
|
}
|
|
|
|
// First projection
|
|
Term::Fst(p) => {
|
|
let p_ty = self.infer(p)?;
|
|
match p_ty {
|
|
Type::Sigma { base, .. } => Ok(*base),
|
|
Type::Product(left, _) => Ok(*left),
|
|
_ => Err(TypeError::NotAPair(format!("{:?}", p_ty))),
|
|
}
|
|
}
|
|
|
|
// Second projection
|
|
Term::Snd(p) => {
|
|
let p_ty = self.infer(p)?;
|
|
match &p_ty {
|
|
Type::Sigma { fiber, .. } => {
|
|
let fst_val = Term::Fst(Box::new((**p).clone()));
|
|
Ok(fiber(&fst_val))
|
|
}
|
|
Type::Product(_, right) => Ok((**right).clone()),
|
|
_ => Err(TypeError::NotAPair(format!("{:?}", p_ty))),
|
|
}
|
|
}
|
|
|
|
// Reflexivity
|
|
Term::Refl(t) => {
|
|
let ty = self.infer(t)?;
|
|
Ok(Type::Id(Box::new(ty), Box::new((**t).clone()), Box::new((**t).clone())))
|
|
}
|
|
|
|
// Transport
|
|
Term::Transport { family, path, term: inner } => {
|
|
// Check that path is an identity type
|
|
let path_ty = self.infer(path)?;
|
|
match path_ty {
|
|
Type::Id(base_ty, source, target) => {
|
|
// Family should map the base type to types
|
|
// For simplicity, assume family is well-typed
|
|
let source_fiber = self.apply_family(family, &source)?;
|
|
self.check(inner, &source_fiber)?;
|
|
let target_fiber = self.apply_family(family, &target)?;
|
|
Ok(target_fiber)
|
|
}
|
|
_ => Err(TypeError::InvalidTransport(
|
|
"Expected identity type".to_string()
|
|
)),
|
|
}
|
|
}
|
|
|
|
// J-eliminator
|
|
Term::J { motive, base_case, left, right, path } => {
|
|
// Verify path type
|
|
let path_ty = self.infer(path)?;
|
|
match path_ty {
|
|
Type::Id(ty, source, target) => {
|
|
// Verify left and right match the path
|
|
if !source.structural_eq(left) || !target.structural_eq(right) {
|
|
return Err(TypeError::InvalidPathInduction(
|
|
"Path endpoints don't match".to_string()
|
|
));
|
|
}
|
|
// The result type is C(left, right, path)
|
|
// For simplicity, use the base case type
|
|
self.infer(base_case)
|
|
}
|
|
_ => Err(TypeError::InvalidPathInduction(
|
|
"Expected identity type".to_string()
|
|
)),
|
|
}
|
|
}
|
|
|
|
// If-then-else
|
|
Term::If { cond, then_branch, else_branch } => {
|
|
self.check(cond, &Type::Bool)?;
|
|
let then_ty = self.infer(then_branch)?;
|
|
self.check(else_branch, &then_ty)?;
|
|
Ok(then_ty)
|
|
}
|
|
|
|
// Natural number recursion
|
|
Term::NatRec { zero_case, succ_case, target } => {
|
|
self.check(target, &Type::Nat)?;
|
|
let result_ty = self.infer(zero_case)?;
|
|
// Verify succ_case has type Nat -> result_ty -> result_ty
|
|
let expected_succ_ty = Type::arrow(
|
|
Type::Nat,
|
|
Type::arrow(result_ty.clone(), result_ty.clone()),
|
|
);
|
|
self.check(succ_case, &expected_succ_ty)?;
|
|
Ok(result_ty)
|
|
}
|
|
|
|
// Case analysis on coproduct
|
|
Term::Case { scrutinee, left_case, right_case } => {
|
|
let scrut_ty = self.infer(scrutinee)?;
|
|
match scrut_ty {
|
|
Type::Coprod(left_ty, right_ty) => {
|
|
let left_result = self.infer(left_case)?;
|
|
match left_result {
|
|
Type::Arrow(_, result) => {
|
|
// Verify right case has matching type
|
|
let expected_right = Type::arrow(*right_ty, *result.clone());
|
|
self.check(right_case, &expected_right)?;
|
|
Ok(*result)
|
|
}
|
|
_ => Err(TypeError::NotAFunction(format!("{:?}", left_result))),
|
|
}
|
|
}
|
|
_ => Err(TypeError::TypeMismatch {
|
|
expected: "coproduct type".to_string(),
|
|
found: format!("{:?}", scrut_ty),
|
|
}),
|
|
}
|
|
}
|
|
|
|
// Abort (ex falso)
|
|
Term::Abort(t) => {
|
|
self.check(t, &Type::Empty)?;
|
|
// Can return any type - for inference, return a type variable
|
|
Ok(Type::Var(format!("?{}", fresh_id())))
|
|
}
|
|
|
|
// Path composition
|
|
Term::PathCompose { left, right } => {
|
|
let left_ty = self.infer(left)?;
|
|
let right_ty = self.infer(right)?;
|
|
|
|
match (&left_ty, &right_ty) {
|
|
(Type::Id(ty1, a, b), Type::Id(ty2, c, d)) => {
|
|
if !ty1.structural_eq(ty2) {
|
|
return Err(TypeError::TypeMismatch {
|
|
expected: format!("{:?}", ty1),
|
|
found: format!("{:?}", ty2),
|
|
});
|
|
}
|
|
if !b.structural_eq(c) {
|
|
return Err(TypeError::PathMismatch {
|
|
left_target: format!("{:?}", b),
|
|
right_source: format!("{:?}", c),
|
|
});
|
|
}
|
|
Ok(Type::Id(ty1.clone(), a.clone(), d.clone()))
|
|
}
|
|
_ => Err(TypeError::TypeMismatch {
|
|
expected: "identity types".to_string(),
|
|
found: format!("{:?} and {:?}", left_ty, right_ty),
|
|
}),
|
|
}
|
|
}
|
|
|
|
// Path inverse
|
|
Term::PathInverse(p) => {
|
|
let p_ty = self.infer(p)?;
|
|
match p_ty {
|
|
Type::Id(ty, a, b) => Ok(Type::Id(ty, b, a)), // a and b are already Box<Term>
|
|
_ => Err(TypeError::TypeMismatch {
|
|
expected: "identity type".to_string(),
|
|
found: format!("{:?}", p_ty),
|
|
}),
|
|
}
|
|
}
|
|
|
|
// ap
|
|
Term::Ap { func, path } => {
|
|
let func_ty = self.infer(func)?;
|
|
let path_ty = self.infer(path)?;
|
|
|
|
match (&func_ty, &path_ty) {
|
|
(Type::Arrow(domain, codomain), Type::Id(ty, a, b)) => {
|
|
if !domain.structural_eq(ty) {
|
|
return Err(TypeError::TypeMismatch {
|
|
expected: format!("{:?}", domain),
|
|
found: format!("{:?}", ty),
|
|
});
|
|
}
|
|
let fa = Term::App {
|
|
func: Box::new((**func).clone()),
|
|
arg: a.clone(),
|
|
};
|
|
let fb = Term::App {
|
|
func: Box::new((**func).clone()),
|
|
arg: b.clone(),
|
|
};
|
|
Ok(Type::Id(codomain.clone(), Box::new(fa), Box::new(fb)))
|
|
}
|
|
(Type::Pi { domain, codomain, .. }, Type::Id(ty, a, b)) => {
|
|
if !domain.structural_eq(ty) {
|
|
return Err(TypeError::TypeMismatch {
|
|
expected: format!("{:?}", domain),
|
|
found: format!("{:?}", ty),
|
|
});
|
|
}
|
|
let fa = Term::App {
|
|
func: Box::new((**func).clone()),
|
|
arg: a.clone(),
|
|
};
|
|
let fb = Term::App {
|
|
func: Box::new((**func).clone()),
|
|
arg: b.clone(),
|
|
};
|
|
// For Pi-types, compute the codomain at b
|
|
let result_ty = codomain(&b);
|
|
Ok(Type::Id(Box::new(result_ty), Box::new(fa), Box::new(fb)))
|
|
}
|
|
_ => Err(TypeError::TypeMismatch {
|
|
expected: "function and identity type".to_string(),
|
|
found: format!("{:?} and {:?}", func_ty, path_ty),
|
|
}),
|
|
}
|
|
}
|
|
|
|
// Let binding
|
|
Term::Let { var, value, body } => {
|
|
let value_ty = self.infer(value)?;
|
|
let extended = self.extend(var.clone(), value_ty);
|
|
extended.infer(body)
|
|
}
|
|
|
|
// Type annotation
|
|
Term::Annot { term: inner, ty } => {
|
|
self.check(inner, ty)?;
|
|
Ok((**ty).clone())
|
|
}
|
|
|
|
// Circle
|
|
Term::CircleBase => Ok(Type::Circle),
|
|
Term::CircleLoop => Ok(Type::Id(
|
|
Box::new(Type::Circle),
|
|
Box::new(Term::CircleBase),
|
|
Box::new(Term::CircleBase),
|
|
)),
|
|
|
|
// Interval
|
|
Term::IntervalZero | Term::IntervalOne => Ok(Type::Interval),
|
|
|
|
// Truncation
|
|
Term::Truncate(t) => {
|
|
let ty = self.infer(t)?;
|
|
Ok(Type::Truncation {
|
|
inner: Box::new(ty),
|
|
level: 0, // Default to set-truncation
|
|
})
|
|
}
|
|
|
|
// Coproduct injections need type annotation for full inference
|
|
Term::Inl(_) | Term::Inr(_) => {
|
|
Err(TypeError::CannotInfer("injection without type annotation".to_string()))
|
|
}
|
|
|
|
// Pair needs type annotation for dependent pairs
|
|
Term::Pair { fst, snd } => {
|
|
let fst_ty = self.infer(fst)?;
|
|
let snd_ty = self.infer(snd)?;
|
|
Ok(Type::Product(Box::new(fst_ty), Box::new(snd_ty)))
|
|
}
|
|
|
|
// Lambda needs type annotation
|
|
Term::Lambda { .. } => {
|
|
Err(TypeError::CannotInfer("lambda without type annotation".to_string()))
|
|
}
|
|
|
|
// apd
|
|
Term::Apd { func, path } => {
|
|
// Similar to ap but for dependent functions
|
|
let path_ty = self.infer(path)?;
|
|
match path_ty {
|
|
Type::Id(_, _, _) => {
|
|
// Result is a dependent path
|
|
self.infer(func)
|
|
}
|
|
_ => Err(TypeError::TypeMismatch {
|
|
expected: "identity type".to_string(),
|
|
found: format!("{:?}", path_ty),
|
|
}),
|
|
}
|
|
}
|
|
|
|
Term::InternalId(_) => Err(TypeError::CannotInfer("internal id".to_string())),
|
|
}
|
|
}
|
|
|
|
/// Normalize a term (beta reduction)
|
|
pub fn normalize(&self, term: &Term) -> Term {
|
|
match term {
|
|
// Beta reduction for application
|
|
Term::App { func, arg } => {
|
|
let func_norm = self.normalize(func);
|
|
let arg_norm = self.normalize(arg);
|
|
|
|
match func_norm {
|
|
Term::Lambda { var, body } => {
|
|
let subst = body.subst(&var, &arg_norm);
|
|
self.normalize(&subst)
|
|
}
|
|
_ => Term::App {
|
|
func: Box::new(func_norm),
|
|
arg: Box::new(arg_norm),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Projection reduction
|
|
Term::Fst(p) => {
|
|
let p_norm = self.normalize(p);
|
|
match p_norm {
|
|
Term::Pair { fst, .. } => self.normalize(&fst),
|
|
_ => Term::Fst(Box::new(p_norm)),
|
|
}
|
|
}
|
|
|
|
Term::Snd(p) => {
|
|
let p_norm = self.normalize(p);
|
|
match p_norm {
|
|
Term::Pair { snd, .. } => self.normalize(&snd),
|
|
_ => Term::Snd(Box::new(p_norm)),
|
|
}
|
|
}
|
|
|
|
// If reduction
|
|
Term::If { cond, then_branch, else_branch } => {
|
|
let cond_norm = self.normalize(cond);
|
|
match cond_norm {
|
|
Term::True => self.normalize(then_branch),
|
|
Term::False => self.normalize(else_branch),
|
|
_ => Term::If {
|
|
cond: Box::new(cond_norm),
|
|
then_branch: Box::new(self.normalize(then_branch)),
|
|
else_branch: Box::new(self.normalize(else_branch)),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Natural recursion reduction
|
|
Term::NatRec { zero_case, succ_case, target } => {
|
|
let target_norm = self.normalize(target);
|
|
match target_norm {
|
|
Term::Zero | Term::NatLit(0) => self.normalize(zero_case),
|
|
Term::Succ(n) => {
|
|
let rec_result = Term::NatRec {
|
|
zero_case: zero_case.clone(),
|
|
succ_case: succ_case.clone(),
|
|
target: n.clone(),
|
|
};
|
|
let app1 = Term::App {
|
|
func: succ_case.clone(),
|
|
arg: n.clone(),
|
|
};
|
|
let app2 = Term::App {
|
|
func: Box::new(app1),
|
|
arg: Box::new(rec_result),
|
|
};
|
|
self.normalize(&app2)
|
|
}
|
|
Term::NatLit(n) if n > 0 => {
|
|
let pred = Term::NatLit(n - 1);
|
|
let rec_result = Term::NatRec {
|
|
zero_case: zero_case.clone(),
|
|
succ_case: succ_case.clone(),
|
|
target: Box::new(pred.clone()),
|
|
};
|
|
let app1 = Term::App {
|
|
func: succ_case.clone(),
|
|
arg: Box::new(pred),
|
|
};
|
|
let app2 = Term::App {
|
|
func: Box::new(app1),
|
|
arg: Box::new(rec_result),
|
|
};
|
|
self.normalize(&app2)
|
|
}
|
|
_ => Term::NatRec {
|
|
zero_case: Box::new(self.normalize(zero_case)),
|
|
succ_case: Box::new(self.normalize(succ_case)),
|
|
target: Box::new(target_norm),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Case reduction
|
|
Term::Case { scrutinee, left_case, right_case } => {
|
|
let scrut_norm = self.normalize(scrutinee);
|
|
match scrut_norm {
|
|
Term::Inl(x) => {
|
|
let app = Term::App {
|
|
func: left_case.clone(),
|
|
arg: x,
|
|
};
|
|
self.normalize(&app)
|
|
}
|
|
Term::Inr(x) => {
|
|
let app = Term::App {
|
|
func: right_case.clone(),
|
|
arg: x,
|
|
};
|
|
self.normalize(&app)
|
|
}
|
|
_ => Term::Case {
|
|
scrutinee: Box::new(scrut_norm),
|
|
left_case: Box::new(self.normalize(left_case)),
|
|
right_case: Box::new(self.normalize(right_case)),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Let reduction
|
|
Term::Let { var, value, body } => {
|
|
let value_norm = self.normalize(value);
|
|
let subst = body.subst(var, &value_norm);
|
|
self.normalize(&subst)
|
|
}
|
|
|
|
// Path composition with refl
|
|
Term::PathCompose { left, right } => {
|
|
let left_norm = self.normalize(left);
|
|
let right_norm = self.normalize(right);
|
|
|
|
match (&left_norm, &right_norm) {
|
|
(Term::Refl(_), _) => right_norm,
|
|
(_, Term::Refl(_)) => left_norm,
|
|
_ => Term::PathCompose {
|
|
left: Box::new(left_norm),
|
|
right: Box::new(right_norm),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Path inverse of refl
|
|
Term::PathInverse(p) => {
|
|
let p_norm = self.normalize(p);
|
|
match p_norm {
|
|
Term::Refl(x) => Term::Refl(x),
|
|
_ => Term::PathInverse(Box::new(p_norm)),
|
|
}
|
|
}
|
|
|
|
// ap on refl
|
|
Term::Ap { func, path } => {
|
|
let func_norm = self.normalize(func);
|
|
let path_norm = self.normalize(path);
|
|
|
|
match &path_norm {
|
|
Term::Refl(x) => {
|
|
let fx = Term::App {
|
|
func: Box::new(func_norm),
|
|
arg: x.clone(),
|
|
};
|
|
Term::Refl(Box::new(self.normalize(&fx)))
|
|
}
|
|
_ => Term::Ap {
|
|
func: Box::new(func_norm),
|
|
path: Box::new(path_norm),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Structural recursion
|
|
Term::Lambda { var, body } => Term::Lambda {
|
|
var: var.clone(),
|
|
body: Box::new(self.normalize(body)),
|
|
},
|
|
|
|
Term::Pair { fst, snd } => Term::Pair {
|
|
fst: Box::new(self.normalize(fst)),
|
|
snd: Box::new(self.normalize(snd)),
|
|
},
|
|
|
|
Term::Succ(n) => Term::Succ(Box::new(self.normalize(n))),
|
|
|
|
Term::Inl(t) => Term::Inl(Box::new(self.normalize(t))),
|
|
Term::Inr(t) => Term::Inr(Box::new(self.normalize(t))),
|
|
|
|
Term::Refl(t) => Term::Refl(Box::new(self.normalize(t))),
|
|
|
|
Term::Truncate(t) => Term::Truncate(Box::new(self.normalize(t))),
|
|
|
|
Term::Annot { term: inner, ty } => Term::Annot {
|
|
term: Box::new(self.normalize(inner)),
|
|
ty: ty.clone(),
|
|
},
|
|
|
|
// J-elimination on refl
|
|
Term::J { motive, base_case, left, right, path } => {
|
|
let path_norm = self.normalize(path);
|
|
match &path_norm {
|
|
Term::Refl(_) => {
|
|
// J(C, c, a, a, refl_a) = c(a)
|
|
let app = Term::App {
|
|
func: base_case.clone(),
|
|
arg: left.clone(),
|
|
};
|
|
self.normalize(&app)
|
|
}
|
|
_ => Term::J {
|
|
motive: Box::new(self.normalize(motive)),
|
|
base_case: Box::new(self.normalize(base_case)),
|
|
left: Box::new(self.normalize(left)),
|
|
right: Box::new(self.normalize(right)),
|
|
path: Box::new(path_norm),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Transport on refl
|
|
Term::Transport { family, path, term: inner } => {
|
|
let path_norm = self.normalize(path);
|
|
match &path_norm {
|
|
Term::Refl(_) => self.normalize(inner),
|
|
_ => Term::Transport {
|
|
family: Box::new(self.normalize(family)),
|
|
path: Box::new(path_norm),
|
|
term: Box::new(self.normalize(inner)),
|
|
},
|
|
}
|
|
}
|
|
|
|
Term::Apd { func, path } => {
|
|
let path_norm = self.normalize(path);
|
|
match &path_norm {
|
|
Term::Refl(x) => {
|
|
let fx = Term::App {
|
|
func: func.clone(),
|
|
arg: x.clone(),
|
|
};
|
|
Term::Refl(Box::new(self.normalize(&fx)))
|
|
}
|
|
_ => Term::Apd {
|
|
func: Box::new(self.normalize(func)),
|
|
path: Box::new(path_norm),
|
|
},
|
|
}
|
|
}
|
|
|
|
Term::Abort(t) => Term::Abort(Box::new(self.normalize(t))),
|
|
|
|
// Values
|
|
Term::Var(_) | Term::Star | Term::True | Term::False |
|
|
Term::Zero | Term::NatLit(_) | Term::CircleBase | Term::CircleLoop |
|
|
Term::IntervalZero | Term::IntervalOne | Term::InternalId(_) => term.clone(),
|
|
}
|
|
}
|
|
|
|
/// Check if two types are equal (up to beta-eta equality)
|
|
pub fn types_equal(&self, t1: &Type, t2: &Type) -> bool {
|
|
// First try structural equality
|
|
if t1.structural_eq(t2) {
|
|
return true;
|
|
}
|
|
|
|
// For more complex equality, we'd need to normalize type terms
|
|
// For now, use structural equality
|
|
false
|
|
}
|
|
|
|
/// Apply a type family (represented as a term) to a term
|
|
fn apply_family(&self, family: &Term, arg: &Term) -> CheckResult<Type> {
|
|
match family {
|
|
Term::Lambda { var, body } => {
|
|
let subst = body.subst(var, arg);
|
|
// Try to interpret the result as a type
|
|
self.term_to_type(&subst)
|
|
}
|
|
_ => {
|
|
// Try applying as a function
|
|
let app = Term::App {
|
|
func: Box::new(family.clone()),
|
|
arg: Box::new(arg.clone()),
|
|
};
|
|
self.term_to_type(&self.normalize(&app))
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Try to interpret a term as a type
|
|
fn term_to_type(&self, term: &Term) -> CheckResult<Type> {
|
|
match term {
|
|
Term::Var(name) => Ok(Type::Var(name.clone())),
|
|
Term::Annot { ty, .. } => Ok((**ty).clone()),
|
|
_ => {
|
|
// For more complex cases, we'd need a more sophisticated approach
|
|
Ok(Type::Var(format!("{:?}", term)))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for TypeChecker {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a variable term.
    fn var(name: &str) -> Term {
        Term::Var(name.to_string())
    }

    /// Shorthand for a boxed natural-number literal.
    fn nat(n: u64) -> Box<Term> {
        Box::new(Term::NatLit(n))
    }

    /// The identity function `fun x => x`.
    fn identity() -> Term {
        Term::Lambda {
            var: "x".to_string(),
            body: Box::new(var("x")),
        }
    }

    /// A checker whose only binding is `x : Nat`.
    fn checker_with_x_nat() -> TypeChecker {
        TypeChecker::new().extend("x".to_string(), Type::Nat)
    }

    #[test]
    fn test_infer_variable() {
        let inferred = checker_with_x_nat().infer(&var("x"));
        assert!(matches!(inferred, Ok(Type::Nat)));
    }

    #[test]
    fn test_infer_refl() {
        let refl_x = Term::Refl(Box::new(var("x")));
        let inferred = checker_with_x_nat().infer(&refl_x);

        assert!(matches!(inferred, Ok(Type::Id(_, _, _))));
    }

    #[test]
    fn test_check_lambda() {
        let nat_to_nat = Type::arrow(Type::Nat, Type::Nat);
        let outcome = TypeChecker::new().check(&identity(), &nat_to_nat);

        assert!(outcome.is_ok());
    }

    #[test]
    fn test_normalize_beta() {
        // (fun x => x) 42
        let redex = Term::App {
            func: Box::new(identity()),
            arg: nat(42),
        };

        let reduced = TypeChecker::new().normalize(&redex);
        assert!(matches!(reduced, Term::NatLit(42)));
    }

    #[test]
    fn test_normalize_proj() {
        // fst (1, 2)
        let first_of_pair = Term::Fst(Box::new(Term::Pair {
            fst: nat(1),
            snd: nat(2),
        }));

        let reduced = TypeChecker::new().normalize(&first_of_pair);
        assert!(matches!(reduced, Term::NatLit(1)));
    }

    #[test]
    fn test_normalize_if() {
        // if true then 1 else 2
        let conditional = Term::If {
            cond: Box::new(Term::True),
            then_branch: nat(1),
            else_branch: nat(2),
        };

        let reduced = TypeChecker::new().normalize(&conditional);
        assert!(matches!(reduced, Term::NatLit(1)));
    }
}
|