注册 登录
编程论坛 Python论坛

ID3算法python实现的问题

大黄鸡 发布于 2018-04-26 16:20, 1735 次点击
程序代码:

from numpy import *
import math
import copy
import pickle as pp
class ID3DTree(object):
    def __init__(self):
        self.tree = {}
        self.dataSet = []
        self.labels = []

    def loadDataSet(self,path,labels):
        recordlist = []
        fp = open(path,"r")
        content = fp.read()
        fp.close()
        rowlist = content.splitlines()
        recordlist = [row.split("\t") for row in rowlist if row.strip()]
        self.dataSet = recordlist
        self.labels = labels

    def train(self):
        labels = copy.deepcopy(self.labels)
        self.tree = self.buildTree(self.dataSet,labels)

    def buildTree(self,dataSet,labels):
        catelist = [data[-1] for data in dataSet]
        if catelist.count(catelist[0]) == len(catelist):
            return catelist[0]
        if len(dataSet[0]) == 1:
            return self.maxCate(catelist)
        besfFeat = self.getBestFeat(dataSet)
        bestFeatLabel = labels[besfFeat]
        tree = {bestFeatLabel:{}}
        del(labels[besfFeat])
        uniqueVals = set([data[bestFeat] for data in dataSet])
        for value in uniqueVals:
            subLabels = labels[:]
            splitDataSet = self.splitDataSet(dataSet,besfFeat,value)
            subTree = self.buildTree(splitDataSet,subLabels)
            tree[bestFeatLabel][value] = subTree
        return tree

    def maxCate(self,catelist):
        items = dict([(catelist.count(i),i) for i in catelist])
        return items[max(items.keys())]

    def getBestFeat(self,dataSet):
        numFeatures = len(dataSet[0]) - 1
        baseEntropy = (dataSet)
        bestInfoGain = 0.0;
        besfFeature = -1
        for i in range(numFeatures):
            uniqueVals = set([data[i] for data in dataSet])
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = self.splitDataSet(dataSet,i,value)
                prob = len(subDataSet)/float(len(dataSet))
                newEntropy += prob * (subDataSet)
            infoGain = baseEntropy - newEntropy
            if (infoGain > bestInfoGain):
                bestInfoGain = infoGain
                besfFeature = id
        return besfFeature

    def computeEntropy(self,dataSet):
        datalen = float(len(dataSet))
        catelist = [data[-1] for data in dataSet]
        items = dict([(i,catelist.count(i)) for i in catelist])
        infoEntropy = 0.0
        for key in items:
            prob = float(items[key])/datalen
            infoEntropy -= prob * math.log(prob,2)
        return infoEntropy

    def splitDataSet(self,dataSet,axis,value):
        rtnList = []
        for featVec in dataSet:
            if featVec[axis] == value:
                rFeatVec = featVec[:axis]
                rFeatVec.extend(featVec[axis+1:])
                rtnList.append(rFeatVec)
        return rtnList
只有本站会员才能查看附件,请 登录

这个是哪里错误了,怎么解决啊?
0 回复
1