ID3算法python实现的问题
程序代码:
from numpy import * import math import copy import pickle as pp class ID3DTree(object): def __init__(self): self.tree = {} self.dataSet = [] self.labels = [] def loadDataSet(self,path,labels): recordlist = [] fp = open(path,"r") content = fp.read() fp.close() rowlist = content.splitlines() recordlist = [row.split("\t") for row in rowlist if row.strip()] self.dataSet = recordlist self.labels = labels def train(self): labels = copy.deepcopy(self.labels) self.tree = self.buildTree(self.dataSet,labels) def buildTree(self,dataSet,labels): catelist = [data[-1] for data in dataSet] if catelist.count(catelist[0]) == len(catelist): return catelist[0] if len(dataSet[0]) == 1: return self.maxCate(catelist) besfFeat = self.getBestFeat(dataSet) bestFeatLabel = labels[besfFeat] tree = {bestFeatLabel:{}} del(labels[besfFeat]) uniqueVals = set([data[bestFeat] for data in dataSet]) for value in uniqueVals: subLabels = labels[:] splitDataSet = self.splitDataSet(dataSet,besfFeat,value) subTree = self.buildTree(splitDataSet,subLabels) tree[bestFeatLabel][value] = subTree return tree def maxCate(self,catelist): items = dict([(catelist.count(i),i) for i in catelist]) return items[max(items.keys())] def getBestFeat(self,dataSet): numFeatures = len(dataSet[0]) - 1 baseEntropy = (dataSet) bestInfoGain = 0.0; besfFeature = -1 for i in range(numFeatures): uniqueVals = set([data[i] for data in dataSet]) newEntropy = 0.0 for value in uniqueVals: subDataSet = self.splitDataSet(dataSet,i,value) prob = len(subDataSet)/float(len(dataSet)) newEntropy += prob * (subDataSet) infoGain = baseEntropy - newEntropy if (infoGain > bestInfoGain): bestInfoGain = infoGain besfFeature = id return besfFeature def computeEntropy(self,dataSet): datalen = float(len(dataSet)) catelist = [data[-1] for data in dataSet] items = dict([(i,catelist.count(i)) for i in catelist]) infoEntropy = 0.0 for key in items: prob = float(items[key])/datalen infoEntropy -= prob * math.log(prob,2) return infoEntropy def splitDataSet(self,dataSet,axis,value): rtnList = [] for featVec in dataSet: if featVec[axis] == value: rFeatVec = featVec[:axis] rFeatVec.extend(featVec[axis+1:]) rtnList.append(rFeatVec) return rtnList
这个是哪里错误了,怎么解决啊?