1. 磐创AI-开放猫官方网站首页
  2. 机器学习
  3. TensorFlow

tf.contrib.legacy_seq2seq.basic_rnn_seq2seq 函数 example 最简单实现

tf.contrib.legacy_seq2seq.basic_rnn_seq2seq 函数 example 最简单实现

函数文档:https://www.tensorflow.org/api_docs/python/tf/contrib/legacy_seq2seq/basic_rnn_seq2seq

import tensorflow as tf
import numpy as np

steps=10
batch_size=10
input_size=10

encoder_inputs = tf.placeholder("float", [None, steps, input_size])
decoder_inputs = tf.placeholder("float", [None, steps, input_size])

en_input=np.zeros(shape=[steps,batch_size,input_size])
de_input=np.zeros(shape=[steps,batch_size,input_size])

cell=tf.nn.rnn_cell.BasicLSTMCell(10)

def get_result(encoder_inputs,decoder_inputs,cell):
    encoder_inputs=tf.unstack(encoder_inputs,axis=1)
    decoder_inputs=tf.unstack(decoder_inputs,axis=1)
    result=tf.contrib.legacy_seq2seq.basic_rnn_seq2seq(
        encoder_inputs,
        decoder_inputs,
        cell,
        dtype=tf.float32,
        scope=None
    )
    return result
result=get_result(encoder_inputs,decoder_inputs,cell)

init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    result_value=sess.run(result,feed_dict={encoder_inputs:en_input,decoder_inputs:de_input})
    print(result_value)

原创文章,作者:fendouai,如若转载,请注明出处:https://panchuang.net/2017/08/07/tf-contrib-legacy_seq2seq-basic_rnn_seq2seq-example/

发表评论

登录后才能评论

联系我们

400-800-8888

在线咨询:点击这里给我发消息

邮件:admin@example.com

工作时间:周一至周五,9:30-18:30,节假日休息