1 # -*- coding: utf-8 -*-
2 from pathlib import Path #从pathlib中导入Path
3 import os
4 import fileinput
5 import random
6 root_path='/home/tay/Videos/trash/垃圾分类项目/total/'
7 train = open('./trash_train.txt','a')
8 test = open('./trash_test.txt','a')
9 pwd = os.getcwd() +'/'# the val data path 训练集的路径
10
11
12 def gen_txt():
13 i =0
14 for file in os.listdir(root_path):
15 print('file is{}'.format(str(file)))
16 for init in os.listdir(os.path.join(root_path, file)): #子文件夹
17 print('init is{}'.format(str(init)))
18 i += 1
19 pathDir = os.listdir(os.path.join(root_path, file, init)) #
20 print('pathDir is', pathDir)
21 file_num = len(pathDir)
22 rate = 0.2
23 pick_num = int(file_num * rate)
24 sample = random.sample(pathDir, pick_num) #随机选取20%的pathDir字符串
25 print('sample is', sample)
26 for pick_name in sample:
27 test.write(root_path.split('total/')[-1] +file + '/' + init +'/' + pick_name + ' ' + str(i) + '\n')
28 # for name in pathDir: #文件夹中的图片名
29 # print('name is{}'.format(str(name)))
30 # if test
31 # total.write(root_path.split('total/')[-1] +file + '/' + init +'/' + name + ' ' + str(i) + '\n' )
32 same = [x for x in pathDir if x in sample] #列表中相同的内容
33 diff = [y for y in (sample + pathDir) if y not in same] #列表中不同的内容
34 print('different', diff)
35 print('same', same)
36 for train_name in diff:
37 train.write(root_path.split('total/')[-1] +file + '/' + init +'/' + train_name + ' ' + str(i) + '\n')
38 gen_txt()
采用了random.sample函数来随机选取特定数量的文件名作为测试集,通过比较两个列表中不同的元素来获取训练集的文件名。
总体上就是在进行字符串操作。