"""
Classes and methods to maintain any bibtex information that is stored
outside the doctree.

.. autoclass:: Citation
    :members:

.. autoclass:: BibtexDomain
    :members:
"""

import ast
import re
from typing import (
    TYPE_CHECKING,
    Dict,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    cast,
)

import docutils.frontend
import docutils.nodes
import docutils.parsers.rst
import docutils.utils
import pybtex.database
import pybtex.plugin
import pybtex_docutils
import sphinx.util
from pybtex.richtext import Tag
from pybtex.style import FormattedEntry
from pybtex.style.template import FieldIsMissing
from sphinx.domains import Domain, ObjType
from sphinx.errors import ExtensionError
from sphinx.locale import _

import sphinxcontrib.bibtex.plugin

from .bibfile import BibData, normpath_filename, process_bibdata
from .citation_target import CitationTarget, parse_citation_targets
from .roles import CiteRole
from .style.referencing import BaseReferenceStyle, format_references
from .style.template import SphinxReferenceInfo

if TYPE_CHECKING:
    from pybtex.database import Entry
    from pybtex.style.formatting import BaseStyle
    from sphinx.addnodes import pending_xref
    from sphinx.application import Sphinx
    from sphinx.builders import Builder
    from sphinx.environment import BuildEnvironment

    from .directives import BibliographyKey, BibliographyValue
    from .roles import CitationRef

logger = sphinx.util.logging.getLogger(__name__)


def _raise_invalid_node(node):
    """Helper method to raise an exception when an invalid node is
    visited.

    :raises ValueError: always, mentioning *node* in the message.
    """
    raise ValueError("invalid node %s in filter expression" % node)


class _FilterVisitor(ast.NodeVisitor):
    """Visit the abstract syntax tree of a parsed filter expression.

    Every ``visit_*`` method returns the value of the visited
    sub-expression. Any node type without a dedicated handler is
    rejected with a :class:`ValueError` through :meth:`generic_visit`.
    """

    entry = None
    """The bibliographic entry to which the filter must be applied."""

    cited_docnames = False
    """The documents where the entry is cited (empty if not cited)."""

    def __init__(self, entry, docname, cited_docnames):
        self.entry = entry
        # name of the document holding the bibliography directive
        self.docname = docname
        self.cited_docnames = cited_docnames

    def visit_Module(self, node):
        if len(node.body) != 1:
            raise ValueError("filter expression cannot contain multiple expressions")
        return self.visit(node.body[0])

    def visit_Expr(self, node):
        return self.visit(node.value)

    def visit_BoolOp(self, node):
        # lazy generator so that all/any short-circuit like and/or do
        outcomes = (self.visit(value) for value in node.values)
        if isinstance(node.op, ast.And):
            return all(outcomes)
        elif isinstance(node.op, ast.Or):
            return any(outcomes)
        else:  # pragma: no cover
            # there are no other boolean operators
            # so this code should never execute
            assert False, "unexpected boolean operator %s" % node.op

    def visit_UnaryOp(self, node):
        if isinstance(node.op, ast.Not):
            return not self.visit(node.operand)
        else:
            _raise_invalid_node(node)

    def visit_BinOp(self, node):
        left = self.visit(node.left)
        op = node.op
        right = self.visit(node.right)
        if isinstance(op, ast.Mod):
            # modulo operator is used for regular expression matching
            if not isinstance(left, str):
                raise ValueError("expected a string on left side of %s" % node.op)
            if not isinstance(right, str):
                raise ValueError("expected a string on right side of %s" % node.op)
            return re.search(right, left, re.IGNORECASE)
        elif isinstance(op, ast.BitOr):
            return left | right
        elif isinstance(op, ast.BitAnd):
            return left & right
        else:
            _raise_invalid_node(node)

    def visit_Compare(self, node):
        # keep it simple: binary comparators only
        if len(node.ops) != 1:
            raise ValueError("syntax for multiple comparators not supported")
        left = self.visit(node.left)
        op = node.ops[0]
        right = self.visit(node.comparators[0])
        if isinstance(op, ast.Eq):
            return left == right
        elif isinstance(op, ast.NotEq):
            return left != right
        elif isinstance(op, ast.Lt):
            return left < right
        elif isinstance(op, ast.LtE):
            return left <= right
        elif isinstance(op, ast.Gt):
            return left > right
        elif isinstance(op, ast.GtE):
            return left >= right
        elif isinstance(op, ast.In):
            return left in right
        elif isinstance(op, ast.NotIn):
            return left not in right
        else:
            # not used currently: ast.Is | ast.IsNot
            _raise_invalid_node(op)

    def visit_Name(self, node):
        """Calculate the value of the given identifier."""
        id_ = node.id
        if id_ == "type":
            return self.entry.type.lower()
        elif id_ == "key":
            return self.entry.key.lower()
        elif id_ == "cited":
            return bool(self.cited_docnames)
        elif id_ == "docname":
            return self.docname
        elif id_ == "docnames":
            return self.cited_docnames
        elif id_ == "author" or id_ == "editor":
            if id_ in self.entry.persons:
                return " and ".join(
                    str(person)  # XXX needs fix in pybtex?
                    for person in self.entry.persons[id_]
                )
            else:
                return ""
        else:
            # any other identifier is looked up as a bibtex field
            return self.entry.fields.get(id_, "")

    def visit_Set(self, node):
        return frozenset(self.visit(elt) for elt in node.elts)

    # NameConstant is Python 3.4 only
    def visit_NameConstant(self, node):
        return node.value  # pragma: no cover

    # Constant is Python 3.6+ only
    # Since 3.8 Num, Str, Bytes, NameConstant and Ellipsis are just Constant
    def visit_Constant(self, node):
        return node.value

    # Not used on 3.8+
    def visit_Str(self, node):
        return node.s  # pragma: no cover

    def generic_visit(self, node):
        _raise_invalid_node(node)


def get_docnames(env):
    """Get document names in order.

    First yield the documents reachable from the root document by
    following the "next" relation of the toctree. Then yield every
    remaining found document, in sorted order.
    """
    rel = env.collect_relations()
    docname = (
        env.config.master_doc if sphinx.version_info < (4, 0) else env.config.root_doc
    )
    docnames = set()
    while docname is not None:
        docnames.add(docname)
        yield docname
        parent, prevdoc, nextdoc = rel[docname]
        docname = nextdoc
    for docname in sorted(env.found_docs - docnames):
        yield docname


class Citation(NamedTuple):
    """Information about a citation."""

    citation_id: str  #: Unique id of this citation.
    bibliography_key: "BibliographyKey"  #: Key of its bibliography directive.
    key: str  #: Key (with prefix).
    entry: "Entry"  #: Entry from pybtex.
    formatted_entry: "FormattedEntry"  #: Formatted entry for bibliography.
    tooltip_entry: Optional["FormattedEntry"]  #: Formatted entry for tooltip.


def env_updated(app: "Sphinx", env: "BuildEnvironment") -> Iterable[str]:
    """Sphinx ``env-updated`` event handler.

    Delegates to :meth:`BibtexDomain.env_updated`.
    """
    domain = cast(BibtexDomain, env.get_domain("cite"))
    return domain.env_updated()


def parse_header(header: str, source_path: str):
    """Parse *header* as reStructuredText and return the first top-level
    node of the resulting document. *source_path* is only used as the
    document's source name.
    """
    parser = docutils.parsers.rst.Parser()
    # note: types stub for docutils doesn't know about components argument
    settings = docutils.frontend.OptionParser(
        components=(docutils.parsers.rst.Parser,)  # type: ignore
    ).get_default_values()
    document = docutils.utils.new_document(source_path, settings)
    parser.parse(header, document)
    return document[0]


class BibtexDomain(Domain):
    """Sphinx domain for the bibtex extension."""

    name = "cite"
    label = "BibTeX Citations"
    data_version = 4
    initial_data = dict(
        bibdata=BibData(
            encoding="", bibfiles={}, data=pybtex.database.BibliographyData()
        ),
        bibliography_header=docutils.nodes.container(),
        bibliographies={},
        citations=[],
        citation_refs=[],
    )
    backend = pybtex_docutils.Backend()
    reference_style: BaseReferenceStyle

    @property
    def bibdata(self) -> BibData:
        """Information about the bibliography files."""
        return self.data["bibdata"]

    @property
    def bibliography_header(self) -> docutils.nodes.Element:
        """Container holding the parsed ``bibtex_bibliography_header``."""
        return self.data["bibliography_header"]

    @property
    def bibliographies(self) -> Dict["BibliographyKey", "BibliographyValue"]:
        """Map storing information about each bibliography directive."""
        return self.data["bibliographies"]

    @property
    def citations(self) -> List[Citation]:
        """Citation data."""
        return self.data["citations"]

    @property
    def citation_refs(self) -> List["CitationRef"]:
        """Citation reference data."""
        return self.data["citation_refs"]

    def __init__(self, env: "BuildEnvironment"):
        # set up referencing style
        style = sphinxcontrib.bibtex.plugin.find_plugin(
            "sphinxcontrib.bibtex.style.referencing",
            env.app.config.bibtex_reference_style,
        )
        self.reference_style = style()
        # set up object types and roles for referencing style
        # (must happen before Domain.__init__, which reads them)
        role_names = self.reference_style.role_names()
        self.object_types = dict(
            citation=ObjType(_("citation"), *role_names, searchprio=-1),
        )
        self.roles = dict((name, CiteRole()) for name in role_names)
        # initialize the domain
        super().__init__(env)
        # connect env-updated
        env.app.connect("env-updated", env_updated)
        # check config
        if env.app.config.bibtex_bibfiles is None:
            raise ExtensionError("You must configure the bibtex_bibfiles setting")
        # update bib file information in the cache
        bibfiles = [
            normpath_filename(env, "/" + bibfile)
            for bibfile in env.app.config.bibtex_bibfiles
        ]
        self.data["bibdata"] = process_bibdata(
            self.bibdata, bibfiles, env.app.config.bibtex_encoding
        )
        # parse bibliography header
        header = getattr(env.app.config, "bibtex_bibliography_header")
        if header:
            self.data["bibliography_header"] = docutils.nodes.container()
            self.data["bibliography_header"] += parse_header(
                header, "bibliography_header"
            )

    def clear_doc(self, docname: str) -> None:
        """Remove all citations, citation references and bibliographies
        that belong to *docname*.
        """
        self.data["citations"] = [
            citation
            for citation in self.citations
            if citation.bibliography_key.docname != docname
        ]
        self.data["citation_refs"] = [
            ref for ref in self.citation_refs if ref.docname != docname
        ]
        for bib_key in list(self.bibliographies.keys()):
            if bib_key.docname == docname:
                del self.bibliographies[bib_key]

    def merge_domaindata(self, docnames: List[str], otherdata: Dict) -> None:
        """Merge bibliographies and citation references for *docnames*
        from *otherdata* into this domain's data.
        """
        for bib_key, bib_value in otherdata["bibliographies"].items():
            if bib_key.docname in docnames:
                self.bibliographies[bib_key] = bib_value
        for citation_ref in otherdata["citation_refs"]:
            if citation_ref.docname in docnames:
                self.citation_refs.append(citation_ref)
        # 'citations' domain data calculated in env_updated

    def env_updated(self) -> Iterable[str]:
        """Recompute the ``citations`` domain data.

        Emits warnings for duplicate citation keys and labels. Returns
        an empty list of updated docnames.
        """
        # This function is called when all doctrees are parsed,
        # but before any post transforms are applied. We use it to
        # determine which citations will be added to which bibliography
        # directive, and also to format the labels. We need to format
        # the labels here because they must be known when resolve_xref is
        # called.
        self.citations.clear()  # might have been restored from pickle
        docnames = list(get_docnames(self.env))
        # we keep track of this to quickly check for duplicates
        used_keys: Set[str] = set()
        used_labels: Dict[str, str] = {}
        for bibliography_key, bibliography in self.bibliographies.items():
            for entry, formatted_entry, tooltip_entry in self.get_formatted_entries(
                bibliography_key,
                docnames,
                self.env.app.config.bibtex_tooltips,
                self.env.app.config.bibtex_tooltips_style,
            ):
                key = bibliography.keyprefix + formatted_entry.key
                if bibliography.list_ == "citation" and key in used_keys:
                    logger.warning(
                        'duplicate citation for key "%s"' % key,
                        location=(bibliography_key.docname, bibliography.line),
                        type="bibtex",
                        subtype="duplicate_citation",
                    )
                self.citations.append(
                    Citation(
                        citation_id=bibliography.citation_nodes[key]["ids"][0],
                        bibliography_key=bibliography_key,
                        key=key,
                        entry=entry,
                        formatted_entry=formatted_entry,
                        tooltip_entry=tooltip_entry,
                    )
                )
                if bibliography.list_ == "citation":
                    used_keys.add(key)
                    if formatted_entry.label not in used_labels:
                        used_labels[formatted_entry.label] = key
                    elif used_labels[formatted_entry.label] != key:
                        # if used_label[label] == key then already
                        # duplicate key warning
                        logger.warning(
                            'duplicate label "%s" for keys "%s" and "%s"'
                            % (
                                formatted_entry.label,
                                used_labels[formatted_entry.label],
                                key,
                            ),
                            location=(bibliography_key.docname, bibliography.line),
                            type="bibtex",
                            subtype="duplicate_label",
                        )
        return []  # expects list of updated docnames

    def resolve_xref(
        self,
        env: "BuildEnvironment",
        fromdocname: str,
        builder: "Builder",
        typ: str,
        target: str,
        node: "pending_xref",
        contnode: docutils.nodes.Element,
    ) -> docutils.nodes.Element:
        """Replace node by list of citation references (one for each key)."""
        targets = parse_citation_targets(target)
        keys: Dict[str, CitationTarget] = {target2.key: target2 for target2 in targets}
        # only citations from "citation"-style bibliographies can be referenced
        citations: Dict[str, Citation] = {
            cit.key: cit
            for cit in self.citations
            if cit.key in keys
            and self.bibliographies[cit.bibliography_key].list_ == "citation"
        }
        for key in keys:
            if key not in citations:
                logger.warning(
                    'could not find bibtex key "%s"' % key,
                    location=node,
                    type="bibtex",
                    subtype="key_not_found",
                )
        plaintext = pybtex.plugin.find_plugin("pybtex.backends", "plaintext")()
        references = [
            (
                citation.entry,
                citation.formatted_entry,
                SphinxReferenceInfo(
                    builder=builder,
                    fromdocname=fromdocname,
                    todocname=citation.bibliography_key.docname,
                    citation_id=citation.citation_id,
                    title=(
                        citation.tooltip_entry.text.render(plaintext).replace(
                            "\\url ", ""
                        )
                        if citation.tooltip_entry
                        else None
                    ),
                    pre_text=keys[citation.key].pre,
                    post_text=keys[citation.key].post,
                ),
            )
            for citation in citations.values()
        ]
        formatted_references = format_references(self.reference_style, typ, references)
        result_node = docutils.nodes.inline(rawsource=target)
        result_node += formatted_references.render(self.backend)
        return result_node

    def resolve_any_xref(
        self,
        env: "BuildEnvironment",
        fromdocname: str,
        builder: "Builder",
        target: str,
        node: "pending_xref",
        contnode: docutils.nodes.Element,
    ) -> List[Tuple[str, docutils.nodes.Element]]:
        """Replace node by list of citation references (one for each key),
        provided that the target has citation keys.
        """
        keys = [key.strip() for key in target.split(",")]
        citations: Set[str] = {
            cit.key
            for cit in self.citations
            if cit.key in keys
            and self.bibliographies[cit.bibliography_key].list_ == "citation"
        }
        if any(key in citations for key in keys):
            result_node = self.resolve_xref(
                env, fromdocname, builder, "p", target, node, contnode
            )
            return [("p", result_node)]
        else:
            return []

    def get_all_cited_keys(self, docnames):
        """Yield all citation keys for given *docnames* in order, then
        ordered by citation order.
        """
        # Position of each docname, keeping the first occurrence like
        # list.index would. Computed once instead of one linear
        # docnames.index() scan per citation reference.
        doc_index: Dict[str, int] = {}
        for position, docname in enumerate(docnames):
            doc_index.setdefault(docname, position)
        for citation_ref in sorted(
            self.citation_refs, key=lambda c: doc_index[c.docname]
        ):
            for target in citation_ref.targets:
                yield target.key

    def get_entries(self, bibfiles: List[str]) -> Iterable["Entry"]:
        """Return all bibliography entries from the bib files, unsorted (i.e.
        in order of appearance in the bib files.
        """
        for bibfile in bibfiles:
            for key in self.bibdata.bibfiles[bibfile].keys:
                yield self.bibdata.data.entries[key]

    def get_filtered_entries(
        self, bibliography_key: "BibliographyKey"
    ) -> Iterable[Tuple[str, "Entry"]]:
        """Return unsorted bibliography entries filtered by the filter
        expression.

        Entries listed explicitly in the directive are always included.
        If the filter expression is invalid, a warning is logged and
        the entry is kept only if it is cited somewhere.
        """
        bibliography = self.bibliographies[bibliography_key]
        # Map each cited key to the set of documents citing it. This is
        # built once, rather than rescanning every citation reference for
        # each entry of the bibliography.
        cited_docnames_by_key: Dict[str, Set[str]] = {}
        for citation_ref in self.citation_refs:
            for target in citation_ref.targets:
                cited_docnames_by_key.setdefault(target.key, set()).add(
                    citation_ref.docname
                )
        for entry in self.get_entries(bibliography.bibfiles):
            key = bibliography.keyprefix + entry.key
            # fresh copy per entry, so visitors never share a set object
            cited_docnames = set(cited_docnames_by_key.get(key, ()))
            visitor = _FilterVisitor(
                entry=entry,
                docname=bibliography_key.docname,
                cited_docnames=cited_docnames,
            )
            try:
                success = visitor.visit(bibliography.filter_)
            except ValueError as err:
                logger.warning(
                    "syntax error in :filter: expression; %s" % err,
                    location=(bibliography_key.docname, bibliography.line),
                    type="bibtex",
                    subtype="filter_syntax_error",
                )
                # recover by falling back to the default
                success = bool(cited_docnames)
            if success or entry.key in bibliography.keys:
                yield key, entry

    def get_sorted_entries(
        self, bibliography_key: "BibliographyKey", docnames: List[str]
    ) -> Iterable[Tuple[str, "Entry"]]:
        """Return filtered bibliography entries sorted by citation order."""
        entries = dict(self.get_filtered_entries(bibliography_key))
        for key in self.get_all_cited_keys(docnames):
            try:
                entry = entries.pop(key)
            except KeyError:
                pass
            else:
                yield key, entry
        # then all remaining keys, in order of bibliography file
        for key, entry in entries.items():
            yield key, entry

    def get_formatted_entries(
        self,
        bibliography_key: "BibliographyKey",
        docnames: List[str],
        tooltips: bool,
        tooltips_style: str,
    ) -> Iterable[Tuple["Entry", "FormattedEntry", Optional["FormattedEntry"]]]:
        """Get sorted bibliography entries along with their pybtex labels,
        with additional sorting and formatting applied from the pybtex style.

        Entries whose formatting raises ``FieldIsMissing`` are logged
        and yielded with a bold error message in place of the formatted
        entry, and no tooltip.
        """
        bibliography = self.bibliographies[bibliography_key]
        entries = dict(self.get_sorted_entries(bibliography_key, docnames))
        style: "BaseStyle" = cast(
            "BaseStyle",
            pybtex.plugin.find_plugin("pybtex.style.formatting", bibliography.style)(),
        )
        # tooltip style: a dedicated style if configured, otherwise the
        # bibliography style; no tooltips at all when disabled
        style2: Optional["BaseStyle"] = (
            (
                cast(
                    "BaseStyle",
                    pybtex.plugin.find_plugin(
                        "pybtex.style.formatting", tooltips_style
                    )(),
                )
                if tooltips_style
                else style
            )
            if tooltips
            else None
        )
        # Materialize: the sorted entries are iterated twice, once by the
        # (lazy) label formatter and once by zip below. An iterator here
        # would be consumed by both and silently mis-pair labels.
        sorted_entries: List["Entry"] = list(style.sort(entries.values()))
        labels = style.format_labels(sorted_entries)
        for label, entry in zip(labels, sorted_entries):
            full_label = bibliography.labelprefix + label
            try:
                formatted_entry = style.format_entry(full_label, entry)
                tooltip_entry = (
                    style2.format_entry(full_label, entry) if style2 else None
                )
            except FieldIsMissing as exc:
                logger.warning(
                    str(exc),
                    location=(bibliography_key.docname, bibliography.line),
                    type="bibtex",
                    subtype="missing_field",
                )
                formatted_error_entry = FormattedEntry(
                    entry.key, Tag("b", str(exc)), full_label
                )
                yield entry, formatted_error_entry, None
            else:
                yield entry, formatted_entry, tooltip_entry