BNFit Class¶
The BNFit class represents a fitted Bayesian Network with parameters learned from data. It provides convenient access to the learned conditional distributions and enables evaluation of the fitted model.
Overview¶
A BNFit object is typically created as a result of parameter learning using the BN.fit() method. It contains:
- The original DAG structure from the parent BN
- Learned conditional distributions with parameters fitted to data
- Metadata about the fitting process (e.g., number of estimated PMFs)
Key Features¶
- Access to Fitted Parameters: Direct access to learned conditional distributions
- Model Evaluation: Methods for assessing fit quality and making predictions
- Serialization: Save and load fitted models
- Integration: Works seamlessly with the broader BN ecosystem
Usage Patterns¶
The BNFit class is primarily used for:
- Model Inspection: Examining learned parameters and distributions
- Prediction: Making probabilistic predictions on new data
- Model Persistence: Saving fitted models for later use
- Model Comparison: Comparing different fitted models
Example Usage¶
from causaliq_core.bn import BN, CPT
from causaliq_core.graph import DAG
import pandas as pd

# Create a BN structure
dag = DAG(['A', 'B', 'C'], [('A', 'B'), ('B', 'C')])
cnd_specs = {
    'A': CPT(values=['T', 'F']),
    'B': CPT(values=['T', 'F'], parents=['A']),
    'C': CPT(values=['T', 'F'], parents=['B'])
}
bn = BN(dag, cnd_specs)

# Prepare training data
data = pd.DataFrame({
    'A': ['T', 'F', 'T', 'F', 'T'],
    'B': ['T', 'F', 'T', 'T', 'F'],
    'C': ['T', 'F', 'T', 'T', 'F']
})

# Fit the model
fitted_bn = BN.fit(bn.dag, data)

# Access fitted distributions
print("Fitted CPT for node B:")
print(fitted_bn.cnds['B'])

# Examine model structure
print(f"Number of free parameters: {fitted_bn.free_params}")
print(f"Nodes: {list(fitted_bn.cnds.keys())}")
Relationship to BN Class¶
The BNFit class is closely related to the BN class:
- BN: Represents a network structure with specified or unspecified parameters
- BNFit: Represents a network with parameters learned from data
- Compatibility: BNFit objects can be used wherever BN objects are expected
API Reference¶
BNFit¶
Interface for Bayesian Network parameter estimation and data access.
This interface provides the essential methods required for fitting conditional probability tables (CPTs) and linear Gaussian models in Bayesian Networks, as well as data access methods for the BN class.
Implementing classes should provide:
- A constructor that accepts a df=DataFrame parameter for BN compatibility
- All abstract methods defined below
- Properties for data access (.nodes, .sample, .node_types)
Methods:
- marginals – Return marginal counts for a node and its parents.
- values – Return the (float) values for specified nodes.
- write – Write data to file.
Attributes:
- N (int) – Total sample size.
- node_types (Dict[str, str]) – Node type mapping for each variable.
- node_values (Dict[str, Dict]) – Node value counts for categorical variables.
- nodes (Tuple[str, ...]) – Column names in the dataset.
- sample (Any) – Access to underlying data sample.
N (abstractmethod, property, writable)¶
Total sample size.
Returns:
- int – Current sample size being used.
node_types (abstractmethod, property)¶
Node type mapping for each variable.
Returns:
- Dict[str, str] – Dictionary mapping node names to their types. Format: {node: 'category' | 'continuous'}
node_values (abstractmethod, property, writable)¶
Node value counts for categorical variables.
Returns:
- Dict[str, Dict] – Values and their counts of categorical nodes in sample. Format: {node1: {val1: count1, val2: count2, ...}, ...}
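To make the documented return format concrete, the following minimal sketch builds a mapping of the same shape from plain Python lists. It is an illustration only, not the library's implementation: the `count_node_values` helper and the `sample` dictionary are hypothetical names introduced here, and the sorted ordering of values is an assumption.

```python
from collections import Counter

# Hypothetical sample data: column name -> observed categorical values.
sample = {
    "A": ["T", "F", "T", "F", "T"],
    "B": ["T", "F", "T", "T", "F"],
}


def count_node_values(columns):
    """Return {node: {value: count}} for each column.

    Values are sorted so the result is deterministic; this ordering is
    an assumption made for the sketch, not documented behaviour.
    """
    return {
        node: dict(sorted(Counter(values).items()))
        for node, values in columns.items()
    }


node_values = count_node_values(sample)
print(node_values)
```

Each inner dictionary plays the role of `{val1: count1, val2: count2, ...}` in the format description above.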
nodes (abstractmethod, property)¶
Column names in the dataset.
Returns:
- Tuple[str, ...] – Tuple of node names (column names) in the dataset.
sample (abstractmethod, property)¶
Access to underlying data sample.
Returns:
- Any – The underlying DataFrame or data structure for direct access. Used for operations like .unique() on columns.
marginals (abstractmethod)¶
marginals(node: str, parents: Dict, values_reqd: bool = False) -> Tuple
Return marginal counts for a node and its parents.
Parameters:
- node (str) – Node for which marginals required.
- parents (Dict) – Dictionary {node: parents} for non-orphan nodes.
- values_reqd (bool, default: False) – Whether parent and child values required.
Returns:
- Tuple – Tuple of counts, and optionally, values:
  - ndarray counts: 2D array, rows=child, cols=parents
  - int maxcol: Maximum number of parental values
  - tuple rowval: Child values for each row
  - tuple colval: Parent combo (dict) for each col
Raises:
- TypeError – For bad argument types.
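The layout of the counts table (rows for child values, columns for parent value combinations) can be sketched in plain Python as follows. This is a simplified illustration under stated assumptions, not the library's code: `marginal_counts` and `records` are hypothetical names, the parents are passed as a simple list of names rather than the `{node: parents}` dictionary, nested lists stand in for the ndarray, and `maxcol` is not reproduced.

```python
# Hypothetical records: one dict per observation.
records = [
    {"A": "T", "B": "T"},
    {"A": "F", "B": "F"},
    {"A": "T", "B": "T"},
    {"A": "F", "B": "T"},
    {"A": "T", "B": "F"},
]


def marginal_counts(rows, node, parent_names):
    """Count child values against each observed parent combination.

    Returns (counts, rowval, colval) where counts[i][j] is the number of
    rows with child value rowval[i] and parent combination colval[j].
    """
    rowval = tuple(sorted({r[node] for r in rows}))
    combos = sorted({tuple(r[p] for p in parent_names) for r in rows})
    colval = tuple(dict(zip(parent_names, combo)) for combo in combos)

    row_index = {value: i for i, value in enumerate(rowval)}
    col_index = {combo: j for j, combo in enumerate(combos)}

    counts = [[0] * len(combos) for _ in rowval]
    for r in rows:
        key = tuple(r[p] for p in parent_names)
        counts[row_index[r[node]]][col_index[key]] += 1
    return counts, rowval, colval


counts, rowval, colval = marginal_counts(records, "B", ["A"])
print(counts)   # rows follow rowval, columns follow colval
print(rowval)
print(colval)
```

Dividing each column by its total would give a maximum-likelihood estimate of the child's conditional distribution for that parent combination, which is why counts of this shape are useful for fitting CPTs.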
values (abstractmethod)¶
values(nodes: Tuple[str, ...]) -> ndarray
Return the (float) values for specified nodes.
Suitable for passing into e.g. linear regression fitting.
Parameters:
- nodes (Tuple[str, ...]) – Nodes for which data required.
Returns:
- ndarray – Numpy array of values, each column for a node.
Raises:
- TypeError – If bad argument type.
- ValueError – If bad argument value.
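A rough sketch of the same contract, using nested lists in place of a Numpy array, is shown below. The `float_columns` helper and the `columns` data are hypothetical, and the exact conditions that trigger `TypeError` and `ValueError` are assumptions chosen for illustration; the source only states that bad argument types and values raise these exceptions.

```python
# Hypothetical continuous data: column name -> observed values.
columns = {
    "X": [1, 2, 3],
    "Y": [0.5, 1.5, 2.0],
}


def float_columns(data, nodes):
    """Return rows of floats, one column per requested node.

    Assumed error handling (not confirmed by the source):
    - TypeError if nodes is not a tuple of strings
    - ValueError if nodes is empty, has duplicates, or names unknown columns
    """
    if not isinstance(nodes, tuple) or not all(isinstance(n, str) for n in nodes):
        raise TypeError("nodes must be a tuple of strings")
    if not nodes or len(set(nodes)) != len(nodes):
        raise ValueError("nodes must be non-empty and unique")
    if any(n not in data for n in nodes):
        raise ValueError("unknown node requested")

    n_rows = len(data[nodes[0]])
    return [[float(data[n][i]) for n in nodes] for i in range(n_rows)]


matrix = float_columns(columns, ("X", "Y"))
print(matrix)
```

A row-per-observation, column-per-node layout like this is the form a regression routine typically expects as its design matrix.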