文档格式如下
def displayInfo(state, county):
    """Print each yearly ratio recorded for a state/county, then the average.

    Reads ``data1.txt`` from the current working directory. Every line is
    split on commas; field 0 is printed as the year, fields 1 and 2 are
    compared against ``state`` and ``county``, and field 5 is the ratio.

    Args:
        state: State name; must equal field 1 exactly (case-sensitive).
        county: County name; must equal field 2 exactly (case-sensitive).

    Prints ``Invalid location was entered.`` when no line matches.

    Raises:
        FileNotFoundError: If ``data1.txt`` does not exist.
        ValueError: If a matching line's ratio field is not numeric.
    """
    print(f"state_name:{state} county_name:{county}")
    ratio_total = 0.0
    match_count = 0
    with open('data1.txt') as f:
        # Iterate the file lazily instead of materialising every line.
        for raw_line in f:
            fields = raw_line.strip().split(",")
            # Blank or truncated lines have no ratio column; skip them
            # instead of failing with IndexError.
            if len(fields) < 6:
                continue
            if fields[1] != state or fields[2] != county:
                continue
            ratio_total += float(fields[5])
            match_count += 1
            print(f"Year:{fields[0]} Ratio:{fields[5]}")
    if match_count == 0:
        print("Invalid location was entered.")
    else:
        print(f"avg_ratio:{ratio_total / match_count:.2f}")
if __name__ == "__main__":
    # Script entry point: ask for the state, then the county (the tuple is
    # evaluated left to right, so the prompts appear in that order), and
    # print the matching records.
    state_, county_ = input("input state name:"), input("input county name:")
    displayInfo(state_, county_)
--result
input state name:Alabama
input county name:Autauga
state_name:Alabama county_name:Autauga
Year:2003 Ratio:48.4
Year:2004 Ratio:46.4
Year:2005 Ratio:44.1
Year:2006 Ratio:44.3
Year:2007 Ratio:43.7
Year:2008 Ratio:41.6
Year:2009 Ratio:38.7
Year:2010 Ratio:34.7
Year:2011 Ratio:31.9
Year:2012 Ratio:30.3
Year:2013 Ratio:27.8
Year:2014 Ratio:26.1
Year:2015 Ratio:24.7
Year:2016 Ratio:23.3
Year:2017 Ratio:22.4
Year:2018 Ratio:21.5
Year:2019 Ratio:21.4
Year:2020 Ratio:20.6
avg_ratio:32.88
def func1(s, c):
    """Print the year and birth rate of every row matching a location.

    Reads ``teenpreganc.txt`` from the current working directory. The first
    line is skipped as a header; every other line is split on commas, with
    field 0 printed as the year, fields 1 and 2 compared against ``s`` and
    ``c``, and field 5 printed as the birth rate.

    Args:
        s: Value that field 1 must equal exactly.
        c: Value that field 2 must equal exactly.

    Prints ``Invalid location was entered`` when no row matches.

    Raises:
        FileNotFoundError: If ``teenpreganc.txt`` does not exist.
    """
    matched = 0
    with open('teenpreganc.txt') as f:
        next(f, None)  # Drop the header line; harmless on an empty file.
        for line in f:
            row = line.rstrip().split(',')
            # Blank or truncated rows cannot be matched or printed; skip
            # them instead of failing with IndexError.
            if len(row) < 6:
                continue
            if row[1] == s and row[2] == c:
                matched += 1
                print('Year:{},Birth Rate:{}'.format(row[0], row[5]))
    if not matched:
        print('Invalid location was entered')
# Read the location to look up and print its yearly birth rates.
s = input("state:")
# The prompt used to read "country:"; func1 compares this value against
# field 2, which the other lookups in this file treat as the county.
c = input("county:")
func1(s, c)
def get_data(county, state):
    """Print the yearly rate for a county/state and the average rate.

    Reads ``teenpreganc.txt`` from the current working directory. Every line
    is split on commas; fields 1 and 2 are compared case-insensitively
    against ``state`` and ``county``, field 0 is printed as the year and
    field 5 as the rate.

    Args:
        county: County name, matched case-insensitively against field 2.
        state: State name, matched case-insensitively against field 1.

    Prints ``Invalid location was entered`` when no line matches; otherwise
    prints one line per match followed by the average rounded to 2 places.

    Raises:
        FileNotFoundError: If ``teenpreganc.txt`` does not exist.
        ValueError: If a matching line's rate field is not numeric.
    """
    # Normalise the query once so the comparison ignores case.
    wanted_county = county.lower()
    wanted_state = state.lower()
    matches = []
    with open('teenpreganc.txt') as f:
        for line in f:
            fields = line.strip().split(',')
            # Blank or truncated lines have no rate column; skip them
            # instead of failing with IndexError.
            if len(fields) < 6:
                continue
            if fields[1].lower() == wanted_state and fields[2].lower() == wanted_county:
                matches.append(fields)
    if not matches:
        print("Invalid location was entered")
        return
    total = 0
    for fields in matches:
        print(f"year:{fields[0]},rate:{fields[5]}")
        total += float(fields[5])
    # Average rounded to two decimal places.
    print(f"average:{round(total / len(matches), 2)}")
# Ask for the county first, then the state. Each prompt is printed on its own
# line and the answer is read from the line that follows it.
print("please input county:")
county = input()
print("please input state:")
state = input()
# get_data takes its arguments in (county, state) order.
get_data(county, state)
看起来好像很难的样子,其实不是。
这个就是考察那些禁用函数的实现。
先把实现思路大概列出来,然后去实现就好。
使用自定义聚合函数(UDAF)来实现:
/**
 * User-defined aggregate function (UDAF) that averages a Long "age" column.
 *
 * The aggregation buffer carries a running sum and a running count; the
 * result is sum / count as a Double. Null inputs are ignored, and the result
 * is null when no non-null value was aggregated.
 *
 * NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0 in
 * favour of Aggregator registered through functions.udaf; consider migrating.
 */
class MyAgeAvgFunction extends UserDefinedAggregateFunction {

  /** Schema of the function's input: a single Long column. */
  // Equivalent form: new StructType().add("age", LongType)
  override def inputSchema: StructType =
    StructType(StructField("age", LongType) :: Nil)

  /**
   * Schema of the aggregation buffer.
   * Slot 0 ("sum") is the total of all ages seen so far;
   * slot 1 ("count") is how many ages were added.
   */
  // Equivalent form: new StructType().add("sum", LongType).add("count", LongType)
  // The first field was previously named "num" although it holds the sum.
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  /** Type of the value returned by evaluate. */
  override def dataType: DataType = DoubleType

  /** The same input always produces the same result. */
  override def deterministic: Boolean = true

  /** Resets the buffer to its starting state. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // running sum of all ages
    buffer(1) = 0L // number of ages added
  }

  /**
   * Folds one input row into the buffer: e.g. for Row(63), adds 63 to the
   * sum in slot 0 and increments the count in slot 1.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Row.getLong throws on a null cell, so null ages are skipped.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  /**
   * Combines two partial buffers, e.g. partitions holding (sum, count) of
   * (321, 6), (128, 2) and (241, 4).
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // total of the ages
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // total head count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /** Computes the final result: sum / count, or null when count is 0. */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    // Without this guard 0 / 0.0 would yield NaN for an empty group.
    if (count == 0L) null else buffer.getLong(0) / count.toDouble
  }
}
/**
 * User-defined aggregate function (UDAF) that sums a Long "age" column.
 *
 * The aggregation buffer holds a single running total. Null inputs are
 * ignored; the result is 0 when nothing was added.
 */
class sumAge extends UserDefinedAggregateFunction {

  /** Schema of the function's input: a single Long column. */
  override def inputSchema: StructType =
    StructType(StructField("age", LongType) :: Nil)

  /** Schema of the aggregation buffer: one Long slot holding the running sum. */
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: Nil)

  /** Type of the value returned by evaluate. */
  override def dataType: DataType = LongType

  /** The same input always produces the same result. */
  override def deterministic: Boolean = true

  /** Starts the running sum at zero. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
  }

  /** Adds one input row's value to the running sum. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Row.getLong throws on a null cell, so null ages are skipped.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
    }
  }

  /** Combines two partial sums. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
  }

  /** Returns the accumulated sum. */
  override def evaluate(buffer: Row): Any = buffer.getLong(0)
}