tf.nn.fixed_unigram_candidate_sampler的参数传递方法

目前想用fixed_unigram_candidate_sampler作负采样,但其中的参数unigrams要求是list类型的变量,在graph里面用占位得到的变量是tensor类型的,传进去总会报错,求问应该如何正确将unigrams传递进去?

```

train_inputs = tf.compat.v1.placeholder(tf.int32, shape=[batch_size])
train_labels = tf.compat.v1.placeholder(tf.int64, shape=[batch_size, 1])

# unigram该如何定义呢?

loss = tf.reduce_mean(
    tf.nn.nce_loss(weights=nce_weights,
                   biases=nce_biases,
                   inputs=embed,
                   labels=train_labels,
                   num_sampled=num_sampled,
                   num_classes=vocabulary_size,
                   sampled_values=tf.nn.fixed_unigram_candidate_sampler(  # 负采样
                       true_classes=train_labels,
                       num_true=1,
                       num_sampled=num_sampled,
                       unique=True,
                       range_max=vocabulary_size,
                       unigrams=unigram
                   )
                   ))

```

定义一个变量,通过set方法对数据类型进行转换,转换为可以接受的数据类型就可以了

您好,我是有问必答小助手,你的问题已经有小伙伴为您解答了问题,您看下是否解决了您的问题,可以追评进行沟通哦~

如果有您比较满意的答案 / 帮您提供解决思路的答案,可以点击【采纳】按钮,给回答的小伙伴一些鼓励哦~~

ps:问答VIP仅需29元,即可享受5次/月 有问必答服务,了解详情>>>https://vip.csdn.net/askvip?utm_source=1146287632