sklearn加载本地下载的数据集

现在不下载了,但是找不到文件报错

img

为啥还是报错

img

def _download_20newsgroups(target_dir, cache_path):
    """Download the 20 newsgroups data and stored it as a zipped pickle."""
    train_path = os.path.join(target_dir, TRAIN_FOLDER)
    test_path = os.path.join(target_dir, TEST_FOLDER)

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    logger.info("Downloading dataset from %s (14 MB)", ARCHIVE.url)
    # archive_path = _fetch_remote(ARCHIVE, dirname=target_dir)

    archive_path = r'C:\Users\BigFish\scikit_learn_data\20news-bydate.tar.gz'


    logger.debug("Decompressing %s", archive_path)
    tarfile.open(archive_path, "r:gz").extractall(path=target_dir)
    os.remove(archive_path)

    # Store a zipped pickle
    cache = dict(
        train=load_files(train_path, encoding="latin1"),
        test=load_files(test_path, encoding="latin1"),
    )
    compressed_content = codecs.encode(pickle.dumps(cache), "zlib_codec")
    with open(cache_path, "wb") as f:
        f.write(compressed_content)

    shutil.rmtree(target_dir)
    return cache

img

img

【以下回答由 GPT 生成】

问题的关键在于如何使用sklearn加载本地下载的数据集。根据提供的代码和问题描述,可以看出代码中的下载和解压缩部分已经完成,并将数据存储为pickle文件,接下来将给出加载本地数据集的代码解决方案。

首先,确保已经安装了scikit-learn库。如果没有安装,可以通过以下命令来安装:

pip install -U scikit-learn

接下来,可以使用pickle库来加载pickle文件中的数据。在给出具体代码之前,需要执行以下几个步骤:

  1. 确保cache_path(pickle文件路径)是正确的,确保文件存在。
  2. 导入必要的库:
import pickle
from sklearn.datasets import load_files

下面是加载pickle文件中的数据的代码解决方案:

def load_data_from_pickle(cache_path):
    with open(cache_path, "rb") as f:
        compressed_content = f.read()
        cache = pickle.loads(compressed_content)
    return cache["train"], cache["test"]

调用该函数并获取训练和测试数据集:

train_data, test_data = load_data_from_pickle(cache_path)

至此,您可以使用train_datatest_data进行数据处理和模型训练。

以下是整个解决方案的总结代码:

import os
import tarfile
import pickle
import codecs
import shutil
import logging

from sklearn.datasets import load_files

def _download_20newsgroups(target_dir, cache_path):
    """下载20个新闻组数据并将其存储为压缩的pickle文件。"""
    train_path = os.path.join(target_dir, "train")
    test_path = os.path.join(target_dir, "test")

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    logger = logging.getLogger(__name__)
    logger.info("正在从%s下载数据集(14 MB)", ARCHIVE.url)

    archive_path = r'C:\Users\BigFish\scikit_learn_data\20news-bydate.tar.gz'

    logger.debug("正在解压缩%s", archive_path)
    tarfile.open(archive_path, "r:gz").extractall(path=target_dir)
    os.remove(archive_path)

    # 存储为压缩的pickle文件
    cache = dict(
        train=load_files(train_path, encoding="latin1"),
        test=load_files(test_path, encoding="latin1"),
    )
    compressed_content = codecs.encode(pickle.dumps(cache), "zlib_codec")
    with open(cache_path, "wb") as f:
        f.write(compressed_content)

    shutil.rmtree(target_dir)
    return cache

def load_data_from_pickle(cache_path):
    with open(cache_path, "rb") as f:
        compressed_content = f.read()
        cache = pickle.loads(compressed_content)
    return cache["train"], cache["test"]

# 假设cache_path是pickle文件路径
train_data, test_data = load_data_from_pickle(cache_path)

希望以上解决方案能对您有帮助。如果还有任何问题,请随时提问。



【相关推荐】



如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^