#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
 This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
  
May 2018 
    
bayesiannetwork.py

A class for representing Bayesian networks

@author: Stan (stalinmunoz@yahoo.com)
"""
from cpt import CPT
class BayesianNetwork:
    """A Bayesian network built from conditional probability tables (CPTs).

    ``cpts`` maps each vertex variable to its table, and ``variables`` is
    the set of those variables. Query strings are parsed with the helpers
    of the ``CPT`` class (``parseVariables``, ``computeKeys``,
    ``stringToLiteralList``).
    """

    def __init__(self, cpts=None):
        """Initialize the network.

        cpts -- optional dictionary with the conditional probability
                tables, indexed by the vertex variable.

        The dictionary is copied. Later calls to ``setTable`` then do not
        mutate the caller's dictionary, and ``cpts`` stays consistent with
        ``variables``.
        """
        # Copy rather than alias: ``variables`` is built from the keys, so
        # sharing the caller's dict would let the two drift apart.
        self.cpts = dict(cpts) if cpts is not None else {}
        self.variables = set(self.cpts)

    def setTable(self, table):
        """Add, or replace, the table stored under ``table.variable``."""
        self.cpts[table.variable] = table
        self.variables.add(table.variable)

    def query(self, q):
        """Return the probability expressed by the query string ``q``.

        A conditional query ``"X|E"`` is evaluated as
        ``query("X,E") / query("E")``. There is no guard for P(E) == 0, so
        that division raises for Python numbers.

        Otherwise, if ``q`` mentions every variable of the network, it is
        evaluated directly with ``atomic``. If not, the result is the sum of
        ``atomic(q + "," + key)`` over the keys that ``CPT.computeKeys``
        yields for the unmentioned variables. Presumably these are all
        assignments of those variables, i.e. marginalization.

        Raises ValueError if ``q`` contains more than one ``|``.
        """
        if "|" in q:
            event, _, evidence = q.partition("|")
            # Previously extra '|' segments were silently dropped, yielding
            # a wrong probability; reject such queries explicitly instead.
            if "|" in evidence:
                raise ValueError(
                    "query may contain at most one '|': %r" % (q,))
            return self.query(event + "," + evidence) / self.query(evidence)

        mentioned = CPT.parseVariables(q)
        missing = self.variables.difference(mentioned)
        if not missing:
            return self.atomic(q)
        # Sum out the variables the query does not mention.
        return sum(self.atomic(q + "," + key)
                   for key in CPT.computeKeys(missing))

    def atomic(self, q):
        """Return the probability of the atomic event ``q``.

        ``q`` is split into literals with ``CPT.stringToLiteralList``. The
        result is the product, over those literals, of the variable's CPT
        entry keyed by its parent literals present in ``q``. When the
        literal's value is false, ``1 - entry`` is used instead.
        """
        literals = CPT.stringToLiteralList(q)
        p = 1
        for lit in literals:
            table = self.cpts[lit.variable]
            # The key joins the parent literals in the order they appear in
            # q. NOTE(review): this assumes CPT lookups are insensitive to
            # parent order, or that callers list parents in the table's
            # order. Confirm against CPT.__getitem__.
            key = ",".join(str(x) for x in literals
                           if x.variable in table.parents)
            entry = table[key]
            p *= entry if lit.value else 1 - entry
        return p

    def __str__(self):
        """Return each table's string form followed by a newline.

        Tables are ordered by ``str`` of their variable. Iterating the
        ``variables`` set directly would make the order depend on hashing.
        """
        return "".join(str(self.cpts[v]) + "\n"
                       for v in sorted(self.variables, key=str))
            
        
        
    