hwpforge_blueprint/
inheritance.rs

1//! Template inheritance resolution with DFS and circular detection.
2//!
3//! This module implements the inheritance chain resolution algorithm for
4//! Blueprint templates, supporting the `extends` keyword for template reuse.
5//!
6//! # Inheritance Model
7//!
8//! Templates can extend parent templates using the `meta.extends` field:
9//!
10//! ```yaml
11//! # parent.yaml
12//! meta:
13//!   name: parent
14//! styles:
15//!   body:
16//!     char_shape: { font: "Arial", size: 10pt }
17//!
18//! # child.yaml
19//! meta:
20//!   name: child
21//!   extends: parent
22//! styles:
23//!   body:
24//!     char_shape: { size: 12pt }  # Overrides only size, inherits font
25//! ```
26//!
27//! After resolution, the child template contains merged styles where child
28//! fields override parent fields.
29//!
30//! # Algorithm
31//!
//! Each template has at most one parent, so resolution is a linear walk up the
//! `extends` chain (depth-first traversal of a single path) with cycle detection:
//!
//! 1. Start from the child template
//! 2. Walk up the `extends` chain collecting ancestors, failing once the chain
//!    has more than `MAX_INHERITANCE_DEPTH` ancestors
//! 3. Detect circular inheritance (visited set of template names)
//! 4. Merge from root to child (parent first, child overrides)
//! 5. Return fully resolved template with no `extends` field
39//!
40//! # Merge Semantics
41//!
42//! - **Styles**: Field-level merge (child fields override parent fields)
43//! - **Page**: Child's page entirely replaces parent's (if present)
44//! - **Markdown mapping**: Field-level merge (child entries override parent)
45
46use std::collections::{HashMap, HashSet};
47
48use indexmap::IndexMap;
49
50use crate::error::{BlueprintError, BlueprintResult};
51use crate::style::PartialStyle;
52use crate::template::{MarkdownMapping, Template};
53
/// Maximum number of ancestors a template's `extends` chain may have.
///
/// [`resolve_template`] walks the chain with a loop (not recursion) and detects
/// cycles separately by template name, so this limit only bounds the length of
/// a legitimate, non-circular chain. Resolving a template that has more
/// ancestors than this fails with [`BlueprintError::InheritanceDepthExceeded`].
pub const MAX_INHERITANCE_DEPTH: usize = 10;
56
57/// Trait for looking up templates by name during inheritance resolution.
58///
59/// This abstraction allows different template storage backends (HashMap,
60/// Vec, file system, etc.) without coupling the resolution algorithm to
61/// a specific implementation.
62pub trait TemplateProvider {
63    /// Retrieves a template by name.
64    ///
65    /// Returns `None` if the template does not exist.
66    fn get_template(&self, name: &str) -> Option<&Template>;
67}
68
69impl TemplateProvider for HashMap<String, Template> {
70    fn get_template(&self, name: &str) -> Option<&Template> {
71        self.get(name)
72    }
73}
74
75impl TemplateProvider for Vec<Template> {
76    fn get_template(&self, name: &str) -> Option<&Template> {
77        self.iter().find(|t| t.meta.name == name)
78    }
79}
80
81/// Resolves a template's inheritance chain into a fully merged template.
82///
83/// This function walks up the `extends` chain, merges styles from parent
84/// to child, and returns a new template with all inherited fields resolved.
85///
86/// # Errors
87///
88/// - [`BlueprintError::CircularInheritance`] if a cycle is detected
89/// - [`BlueprintError::TemplateNotFound`] if a parent template is missing
90/// - [`BlueprintError::InheritanceDepthExceeded`] if depth exceeds limit
91///
92/// # Example
93///
94/// ```ignore
95/// let templates = HashMap::from([
96///     ("base".into(), base_template),
97///     ("child".into(), child_template),
98/// ]);
99///
100/// let resolved = resolve_template(&child_template, &templates)?;
101/// assert!(resolved.meta.extends.is_none()); // No extends after resolution
102/// ```
103pub fn resolve_template(
104    template: &Template,
105    provider: &dyn TemplateProvider,
106) -> BlueprintResult<Template> {
107    // No inheritance chain: return clone
108    if template.meta.extends.is_none() {
109        return Ok(template.clone());
110    }
111
112    // Collect ancestors via DFS
113    let mut ancestors = Vec::new();
114    let mut visited = HashSet::new();
115    visited.insert(template.meta.name.clone()); // Mark starting template as visited
116    let mut current = template;
117    let mut chain = vec![template.meta.name.clone()];
118
119    while let Some(ref parent_name) = current.meta.extends {
120        // Depth limit check
121        if ancestors.len() >= MAX_INHERITANCE_DEPTH {
122            return Err(BlueprintError::InheritanceDepthExceeded {
123                depth: ancestors.len() + 1,
124                max: MAX_INHERITANCE_DEPTH,
125            });
126        }
127
128        // Circular detection (check before adding to visited)
129        if visited.contains(parent_name) {
130            chain.push(parent_name.clone());
131            return Err(BlueprintError::CircularInheritance { chain });
132        }
133        visited.insert(parent_name.clone());
134
135        // Lookup parent
136        let parent = provider
137            .get_template(parent_name)
138            .ok_or_else(|| BlueprintError::TemplateNotFound { name: parent_name.clone() })?;
139
140        ancestors.push(parent.clone());
141        chain.push(parent_name.clone());
142        current = parent;
143    }
144
145    // Merge from root to child (reverse order)
146    let mut merged = ancestors
147        .into_iter()
148        .rev()
149        .fold(template.clone(), |acc, parent| merge_templates(&parent, &acc));
150
151    // Clear extends after resolution
152    merged.meta.extends = None;
153
154    Ok(merged)
155}
156
157/// Merges a base template with a child template.
158///
159/// - **Styles**: For each style in child, merge with base; inherit base-only styles
160/// - **Page**: Child's page replaces base's (if present)
161/// - **Markdown mapping**: Field-level merge (child fields override base fields)
162///
163/// Returns a new template with merged data.
164fn merge_templates(base: &Template, child: &Template) -> Template {
165    // Merge styles: start with base, apply child overrides
166    let mut merged_styles: IndexMap<String, PartialStyle> = base.styles.clone();
167
168    for (name, child_style) in &child.styles {
169        if let Some(base_style) = merged_styles.get_mut(name) {
170            // Style exists in both: field-level merge
171            base_style.merge(child_style);
172        } else {
173            // New style in child: add it
174            merged_styles.insert(name.clone(), child_style.clone());
175        }
176    }
177
178    // Page: child replaces base entirely
179    let merged_page = child.page.clone().or_else(|| base.page.clone());
180
181    // Markdown mapping: field-level merge
182    let merged_md =
183        merge_markdown_mappings(base.markdown_mapping.as_ref(), child.markdown_mapping.as_ref());
184
185    Template {
186        meta: child.meta.clone(), // Child's meta takes precedence
187        page: merged_page,
188        styles: merged_styles,
189        markdown_mapping: merged_md,
190    }
191}
192
193/// Merges two MarkdownMapping structs (base + child override).
194fn merge_markdown_mappings(
195    base: Option<&MarkdownMapping>,
196    child: Option<&MarkdownMapping>,
197) -> Option<MarkdownMapping> {
198    match (base, child) {
199        (None, None) => None,
200        (Some(b), None) => Some(b.clone()),
201        (None, Some(c)) => Some(c.clone()),
202        (Some(b), Some(c)) => {
203            let mut merged = b.clone();
204            // Child fields override base fields (only if Some)
205            if c.body.is_some() {
206                merged.body.clone_from(&c.body);
207            }
208            if c.heading1.is_some() {
209                merged.heading1.clone_from(&c.heading1);
210            }
211            if c.heading2.is_some() {
212                merged.heading2.clone_from(&c.heading2);
213            }
214            if c.heading3.is_some() {
215                merged.heading3.clone_from(&c.heading3);
216            }
217            if c.heading4.is_some() {
218                merged.heading4.clone_from(&c.heading4);
219            }
220            if c.heading5.is_some() {
221                merged.heading5.clone_from(&c.heading5);
222            }
223            if c.heading6.is_some() {
224                merged.heading6.clone_from(&c.heading6);
225            }
226            if c.code.is_some() {
227                merged.code.clone_from(&c.code);
228            }
229            if c.blockquote.is_some() {
230                merged.blockquote.clone_from(&c.blockquote);
231            }
232            if c.list_item.is_some() {
233                merged.list_item.clone_from(&c.list_item);
234            }
235            Some(merged)
236        }
237    }
238}
239
240// ---------------------------------------------------------------------------
241// Tests
242// ---------------------------------------------------------------------------
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::style::{PartialCharShape, PartialParaShape};
    use crate::template::{PageStyle, TemplateMeta};
    use hwpforge_foundation::{Alignment, HwpUnit};
    use pretty_assertions::assert_eq;

    /// Helper to create a minimal template for testing.
    ///
    /// The template has version `1.0.0`, no description, no page and no
    /// markdown mapping; tests that need those set the fields afterwards.
    fn make_template(
        name: &str,
        extends: Option<&str>,
        styles: Vec<(&str, PartialStyle)>,
    ) -> Template {
        Template {
            meta: TemplateMeta {
                name: name.to_string(),
                version: "1.0.0".to_string(),
                description: None,
                extends: extends.map(|s| s.to_string()),
            },
            page: None,
            styles: styles.into_iter().map(|(k, v)| (k.to_string(), v)).collect(),
            markdown_mapping: None,
        }
    }

    /// Helper to create a partial style with just char_shape font.
    fn style_font(font: &str) -> PartialStyle {
        PartialStyle {
            char_shape: Some(PartialCharShape {
                font: Some(font.to_string()),
                ..Default::default()
            }),
            para_shape: None,
        }
    }

    /// Helper to create a partial style with just char_shape size.
    fn style_size(size: HwpUnit) -> PartialStyle {
        PartialStyle {
            char_shape: Some(PartialCharShape { size: Some(size), ..Default::default() }),
            para_shape: None,
        }
    }

    /// Helper to create a partial style with para_shape alignment.
    fn style_align(align: Alignment) -> PartialStyle {
        PartialStyle {
            char_shape: None,
            para_shape: Some(PartialParaShape { alignment: Some(align), ..Default::default() }),
        }
    }

    /// Builds the linear chain `t0 <- t1 <- ... <- t{last}`.
    ///
    /// Each `t{i}` extends `t{i-1}` and defines one style named after itself;
    /// the map is keyed by template name.
    fn make_chain(last: usize) -> HashMap<String, Template> {
        (0..=last)
            .map(|i| {
                let name = format!("t{}", i);
                let parent = i.checked_sub(1).map(|p| format!("t{}", p));
                let tmpl = make_template(
                    &name,
                    parent.as_deref(),
                    vec![(name.as_str(), style_font(&name))],
                );
                (name, tmpl)
            })
            .collect()
    }

    #[test]
    fn no_inheritance_returns_same_template() {
        let tmpl = make_template("base", None, vec![("body", style_font("Arial"))]);
        let provider = HashMap::<String, Template>::new();

        let resolved = resolve_template(&tmpl, &provider).unwrap();

        assert_eq!(resolved.meta.name, "base");
        assert_eq!(resolved.meta.extends, None);
        assert_eq!(resolved.styles.len(), 1);
    }

    #[test]
    fn single_inheritance_merges_styles() {
        let parent = make_template("parent", None, vec![("body", style_font("Arial"))]);
        let child = make_template(
            "child",
            Some("parent"),
            vec![("body", style_size(HwpUnit::from_pt(12.0).unwrap()))],
        );

        let provider = HashMap::from([
            ("parent".to_string(), parent.clone()),
            ("child".to_string(), child.clone()),
        ]);

        let resolved = resolve_template(&child, &provider).unwrap();

        assert_eq!(resolved.meta.name, "child");
        assert_eq!(resolved.meta.extends, None); // Cleared after resolution

        let body_style = resolved.styles.get("body").unwrap();
        assert_eq!(body_style.char_shape.as_ref().unwrap().font, Some("Arial".to_string()));
        assert_eq!(
            body_style.char_shape.as_ref().unwrap().size,
            Some(HwpUnit::from_pt(12.0).unwrap())
        );
    }

    #[test]
    fn two_level_inheritance_merges_grandparent() {
        let grandparent = make_template("grandparent", None, vec![("body", style_font("Times"))]);
        let parent = make_template(
            "parent",
            Some("grandparent"),
            vec![("body", style_size(HwpUnit::from_pt(10.0).unwrap()))],
        );
        let child =
            make_template("child", Some("parent"), vec![("body", style_align(Alignment::Center))]);

        let provider = HashMap::from([
            ("grandparent".to_string(), grandparent),
            ("parent".to_string(), parent),
            ("child".to_string(), child.clone()),
        ]);

        let resolved = resolve_template(&child, &provider).unwrap();

        let body = resolved.styles.get("body").unwrap();
        assert_eq!(body.char_shape.as_ref().unwrap().font, Some("Times".to_string()));
        assert_eq!(body.char_shape.as_ref().unwrap().size, Some(HwpUnit::from_pt(10.0).unwrap()));
        assert_eq!(body.para_shape.as_ref().unwrap().alignment, Some(Alignment::Center));
    }

    #[test]
    fn circular_two_cycle_detected() {
        let a = make_template("a", Some("b"), vec![]);
        let b = make_template("b", Some("a"), vec![]);

        let provider = HashMap::from([("a".to_string(), a.clone()), ("b".to_string(), b)]);

        let err = resolve_template(&a, &provider).unwrap_err();

        match err {
            BlueprintError::CircularInheritance { chain } => {
                // a -> b -> a
                assert_eq!(chain, vec!["a".to_string(), "b".to_string(), "a".to_string()]);
            }
            _ => panic!("Expected CircularInheritance error, got {:?}", err),
        }
    }

    #[test]
    fn circular_self_reference_detected() {
        let a = make_template("a", Some("a"), vec![]);
        let provider = HashMap::from([("a".to_string(), a.clone())]);

        let err = resolve_template(&a, &provider).unwrap_err();

        match err {
            BlueprintError::CircularInheritance { chain } => {
                assert_eq!(chain, vec!["a".to_string(), "a".to_string()]);
            }
            _ => panic!("Expected CircularInheritance error, got {:?}", err),
        }
    }

    #[test]
    fn template_not_found_error() {
        let child = make_template("child", Some("missing"), vec![]);
        let provider = HashMap::<String, Template>::new();

        let err = resolve_template(&child, &provider).unwrap_err();

        match err {
            BlueprintError::TemplateNotFound { name } => {
                assert_eq!(name, "missing");
            }
            _ => panic!("Expected TemplateNotFound error, got {:?}", err),
        }
    }

    #[test]
    fn missing_grandparent_reports_grandparent_name() {
        let parent = make_template("parent", Some("ghost"), vec![]);
        let child = make_template("child", Some("parent"), vec![]);

        let provider =
            HashMap::from([("parent".to_string(), parent), ("child".to_string(), child.clone())]);

        let err = resolve_template(&child, &provider).unwrap_err();

        match err {
            BlueprintError::TemplateNotFound { name } => {
                assert_eq!(name, "ghost");
            }
            _ => panic!("Expected TemplateNotFound error, got {:?}", err),
        }
    }

    #[test]
    fn depth_limit_exceeded() {
        // t{MAX+1} -> ... -> t0 gives the leaf MAX_INHERITANCE_DEPTH + 1
        // ancestors, one more than allowed.
        let leaf = MAX_INHERITANCE_DEPTH + 1;
        let templates = make_chain(leaf);

        let child = templates.get(&format!("t{}", leaf)).unwrap();
        let err = resolve_template(child, &templates).unwrap_err();

        match err {
            BlueprintError::InheritanceDepthExceeded { depth, max } => {
                assert_eq!(max, MAX_INHERITANCE_DEPTH);
                assert_eq!(depth, MAX_INHERITANCE_DEPTH + 1);
            }
            _ => panic!("Expected InheritanceDepthExceeded error, got {:?}", err),
        }
    }

    #[test]
    fn chain_at_max_depth_resolves() {
        // t{MAX} -> ... -> t0: exactly MAX_INHERITANCE_DEPTH ancestors is allowed.
        let leaf = MAX_INHERITANCE_DEPTH;
        let templates = make_chain(leaf);

        let child = templates.get(&format!("t{}", leaf)).unwrap();
        let resolved = resolve_template(child, &templates).unwrap();

        assert_eq!(resolved.meta.name, format!("t{}", leaf));
        assert_eq!(resolved.meta.extends, None);
        // One distinct style per template in the chain, all inherited.
        assert_eq!(resolved.styles.len(), MAX_INHERITANCE_DEPTH + 1);
    }

    #[test]
    fn child_overrides_parent_field() {
        let parent = make_template(
            "parent",
            None,
            vec![(
                "body",
                PartialStyle {
                    char_shape: Some(PartialCharShape {
                        font: Some("Arial".to_string()),
                        size: Some(HwpUnit::from_pt(10.0).unwrap()),
                        bold: Some(false),
                        ..Default::default()
                    }),
                    para_shape: None,
                },
            )],
        );

        let child = make_template(
            "child",
            Some("parent"),
            vec![(
                "body",
                PartialStyle {
                    char_shape: Some(PartialCharShape {
                        bold: Some(true), // Override only bold
                        ..Default::default()
                    }),
                    para_shape: None,
                },
            )],
        );

        let provider =
            HashMap::from([("parent".to_string(), parent), ("child".to_string(), child.clone())]);

        let resolved = resolve_template(&child, &provider).unwrap();
        let body = resolved.styles.get("body").unwrap();

        assert_eq!(body.char_shape.as_ref().unwrap().font, Some("Arial".to_string()));
        assert_eq!(body.char_shape.as_ref().unwrap().size, Some(HwpUnit::from_pt(10.0).unwrap()));
        assert_eq!(body.char_shape.as_ref().unwrap().bold, Some(true)); // Overridden
    }

    #[test]
    fn parent_only_style_inherited() {
        let parent = make_template(
            "parent",
            None,
            vec![("body", style_font("Arial")), ("heading", style_font("Times"))],
        );

        let child = make_template(
            "child",
            Some("parent"),
            vec![("body", style_size(HwpUnit::from_pt(12.0).unwrap()))], // Only overrides body
        );

        let provider =
            HashMap::from([("parent".to_string(), parent), ("child".to_string(), child.clone())]);

        let resolved = resolve_template(&child, &provider).unwrap();

        assert!(resolved.styles.contains_key("body"));
        assert!(resolved.styles.contains_key("heading")); // Inherited from parent
        assert_eq!(
            resolved.styles.get("heading").unwrap().char_shape.as_ref().unwrap().font,
            Some("Times".to_string())
        );
    }

    #[test]
    fn child_page_replaces_parent_page() {
        let mut parent = make_template("parent", None, vec![]);
        parent.page = Some(PageStyle::a4());

        let mut child = make_template("child", Some("parent"), vec![]);
        child.page = Some(PageStyle::default());

        let provider =
            HashMap::from([("parent".to_string(), parent), ("child".to_string(), child.clone())]);

        let resolved = resolve_template(&child, &provider).unwrap();

        // Child's page should be preserved (not parent's)
        assert!(resolved.page.is_some());
        // Child page is default (all None) since we used PageStyle::default()
        assert!(resolved.page.as_ref().unwrap().width.is_none());
    }

    #[test]
    fn no_child_page_inherits_parent_page() {
        let mut parent = make_template("parent", None, vec![]);
        parent.page = Some(PageStyle::a4());

        // No page in child
        let child = make_template("child", Some("parent"), vec![]);

        let provider =
            HashMap::from([("parent".to_string(), parent), ("child".to_string(), child.clone())]);

        let resolved = resolve_template(&child, &provider).unwrap();

        assert!(resolved.page.is_some()); // Inherited from parent
        assert!(resolved.page.as_ref().unwrap().width.is_some()); // A4 width
    }

    #[test]
    fn markdown_mapping_child_overrides_parent_entries() {
        let mut parent = make_template("parent", None, vec![]);
        parent.markdown_mapping = Some(MarkdownMapping {
            heading1: Some("heading1".to_string()),
            heading2: Some("heading2".to_string()),
            ..Default::default()
        });

        let mut child = make_template("child", Some("parent"), vec![]);
        child.markdown_mapping = Some(MarkdownMapping {
            heading1: Some("custom_h1".to_string()), // Override
            heading3: Some("heading3".to_string()),  // Add new
            ..Default::default()
        });

        let provider =
            HashMap::from([("parent".to_string(), parent), ("child".to_string(), child.clone())]);

        let resolved = resolve_template(&child, &provider).unwrap();
        let md = resolved.markdown_mapping.unwrap();

        assert_eq!(md.heading1, Some("custom_h1".to_string())); // Overridden
        assert_eq!(md.heading2, Some("heading2".to_string())); // Inherited
        assert_eq!(md.heading3, Some("heading3".to_string())); // Added
    }

    #[test]
    fn template_provider_hashmap_lookup() {
        let tmpl = make_template("test", None, vec![]);
        let provider = HashMap::from([("test".to_string(), tmpl.clone())]);

        assert!(provider.get_template("test").is_some());
        assert!(provider.get_template("missing").is_none());
    }

    #[test]
    fn template_provider_vec_lookup() {
        let t1 = make_template("t1", None, vec![]);
        let t2 = make_template("t2", None, vec![]);
        let provider = vec![t1, t2];

        assert!(provider.get_template("t1").is_some());
        assert!(provider.get_template("t2").is_some());
        assert!(provider.get_template("missing").is_none());
    }

    #[test]
    fn vec_provider_resolves_inheritance() {
        let parent = make_template("parent", None, vec![("body", style_font("Arial"))]);
        let child = make_template(
            "child",
            Some("parent"),
            vec![("body", style_size(HwpUnit::from_pt(12.0).unwrap()))],
        );
        let provider = vec![parent, child.clone()];

        let resolved = resolve_template(&child, &provider).unwrap();
        let body = resolved.styles.get("body").unwrap();

        assert_eq!(body.char_shape.as_ref().unwrap().font, Some("Arial".to_string()));
        assert_eq!(body.char_shape.as_ref().unwrap().size, Some(HwpUnit::from_pt(12.0).unwrap()));
    }

    #[test]
    fn child_adds_new_style_not_in_parent() {
        let parent = make_template("parent", None, vec![("body", style_font("Arial"))]);
        let child = make_template(
            "child",
            Some("parent"),
            vec![
                ("body", style_size(HwpUnit::from_pt(12.0).unwrap())),
                ("caption", style_font("Times")), // New style
            ],
        );

        let provider =
            HashMap::from([("parent".to_string(), parent), ("child".to_string(), child.clone())]);

        let resolved = resolve_template(&child, &provider).unwrap();

        assert_eq!(resolved.styles.len(), 2);
        assert!(resolved.styles.contains_key("body"));
        assert!(resolved.styles.contains_key("caption"));
    }

    #[test]
    fn three_level_inheritance_chain() {
        let root = make_template("root", None, vec![("s", style_font("A"))]);
        let mid = make_template(
            "mid",
            Some("root"),
            vec![("s", style_size(HwpUnit::from_pt(10.0).unwrap()))],
        );
        let leaf = make_template("leaf", Some("mid"), vec![("s", style_align(Alignment::Right))]);

        let provider = HashMap::from([
            ("root".to_string(), root),
            ("mid".to_string(), mid),
            ("leaf".to_string(), leaf.clone()),
        ]);

        let resolved = resolve_template(&leaf, &provider).unwrap();
        let s = resolved.styles.get("s").unwrap();

        assert_eq!(s.char_shape.as_ref().unwrap().font, Some("A".to_string()));
        assert_eq!(s.char_shape.as_ref().unwrap().size, Some(HwpUnit::from_pt(10.0).unwrap()));
        assert_eq!(s.para_shape.as_ref().unwrap().alignment, Some(Alignment::Right));
    }

    #[test]
    fn circular_three_cycle_detected() {
        let a = make_template("a", Some("b"), vec![]);
        let b = make_template("b", Some("c"), vec![]);
        let c = make_template("c", Some("a"), vec![]);

        let provider = HashMap::from([
            ("a".to_string(), a.clone()),
            ("b".to_string(), b),
            ("c".to_string(), c),
        ]);

        let err = resolve_template(&a, &provider).unwrap_err();

        match err {
            BlueprintError::CircularInheritance { chain } => {
                // a -> b -> c -> a
                assert_eq!(
                    chain,
                    vec!["a".to_string(), "b".to_string(), "c".to_string(), "a".to_string()]
                );
            }
            _ => panic!("Expected CircularInheritance error, got {:?}", err),
        }
    }
}