#include <string>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/layers/sequence_layers.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

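// Names of the state blobs carried across unrolled timesteps: the previous
// hidden state and cell state are fed into the unrolled net as h_0 and c_0.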
template <typename Dtype>
void ALSTMLayer<Dtype>::RecurrentInputBlobNames(vector<string>* names) const {
  names->resize(2);
  (*names)[0] = "h_0";
  (*names)[1] = "c_0";
}

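// Names of the state blobs handed back after the final timestep. The naming
// is intentionally asymmetric, following the base RecurrentLayer convention:
// the hidden state uses the actual final timestep index ("h_<T>"), while the
// cell state is exposed under the fixed name "c_T".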
template <typename Dtype>
void ALSTMLayer<Dtype>::RecurrentOutputBlobNames(vector<string>* names) const {
  names->resize(2);
  (*names)[0] = "h_" + format_int(this->T_);
  (*names)[1] = "c_T";
}

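// Shapes of the recurrent inputs: both h_0 and c_0 hold a single timestep of
// state for the whole batch, i.e. 1 x N x num_output. In the generated
// prototxt this appears as:
//   input_shape { dim: 1 dim: <N> dim: <num_output> }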
template <typename Dtype>
void ALSTMLayer<Dtype>::RecurrentInputShapes(vector<BlobShape>* shapes) const {
  const int num_output = this->layer_param_.recurrent_param().num_output();
  const int num_blobs = 2;
  shapes->resize(num_blobs);
  for (int i = 0; i < num_blobs; ++i) {
    (*shapes)[i].Clear();
    (*shapes)[i].add_dim(1);  // a single timestep
    (*shapes)[i].add_dim(this->N_);
    (*shapes)[i].add_dim(num_output);
  }
}

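// Unlike the plain LSTM, which exposes only the hidden-state sequence "h",
// this layer also exposes a second top blob, "mask": the per-timestep
// attention weights concatenated in the unrolled net below.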
template <typename Dtype>
void ALSTMLayer<Dtype>::OutputBlobNames(vector<string>* names) const {
  names->resize(2);
  (*names)[0] = "h";
  (*names)[1] = "mask";
}

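// Builds the unrolled attention-LSTM as an explicit NetParameter: reusable
// helper LayerParameters are defined first, followed by the net inputs and
// slicing layers, and finally one group of layers per timestep t = 1..T.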
template <typename Dtype>
void ALSTMLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
  const int num_output = this->layer_param_.recurrent_param().num_output();
  CHECK_GT(num_output, 0) << "num_output must be positive";
  const FillerParameter& weight_filler =
      this->layer_param_.recurrent_param().weight_filler();
  const FillerParameter& bias_filler =
      this->layer_param_.recurrent_param().bias_filler();

  // Add generic LayerParameter's (without bottoms/tops) of layer types we'll
  // use to save redundant code.
  LayerParameter hidden_param;
  hidden_param.set_type("InnerProduct");
  // All four LSTM gate pre-activations (input, forget, output, and candidate)
  // are produced by a single InnerProduct, hence num_output * 4.
  hidden_param.mutable_inner_product_param()->set_num_output(num_output * 4);
  hidden_param.mutable_inner_product_param()->set_bias_term(false);
  hidden_param.mutable_inner_product_param()->set_axis(1);
  hidden_param.mutable_inner_product_param()->
      mutable_weight_filler()->CopyFrom(weight_filler);

  LayerParameter biased_hidden_param(hidden_param);
  biased_hidden_param.mutable_inner_product_param()->set_bias_term(true);
  biased_hidden_param.mutable_inner_product_param()->
      mutable_bias_filler()->CopyFrom(bias_filler);

  // InnerProduct parameters for the attention module: a projection to a
  // hard-coded 256-dimensional embedding. The projection is applied along
  // axis 2, treating the first two axes as batch dimensions.
  LayerParameter attention_param;
  attention_param.set_type("InnerProduct");
  attention_param.mutable_inner_product_param()->set_num_output(256);
  attention_param.mutable_inner_product_param()->set_bias_term(false);
  attention_param.mutable_inner_product_param()->set_axis(2);
  attention_param.mutable_inner_product_param()->
      mutable_weight_filler()->CopyFrom(weight_filler);

  // The biased variant adds a bias term (and its filler) on top of the
  // weights.
  LayerParameter biased_attention_param(attention_param);
  biased_attention_param.mutable_inner_product_param()->set_bias_term(true);
  biased_attention_param.mutable_inner_product_param()->
      mutable_bias_filler()->CopyFrom(bias_filler);

  LayerParameter sum_param;
  sum_param.set_type("Eltwise");
  sum_param.mutable_eltwise_param()->set_operation(
      EltwiseParameter_EltwiseOp_SUM);

  LayerParameter slice_param;
  slice_param.set_type("Slice");
  slice_param.mutable_slice_param()->set_axis(0);

  LayerParameter softmax_param;
  softmax_param.set_type("Softmax");
  softmax_param.mutable_softmax_param()->set_axis(-1);

  LayerParameter split_param;
  split_param.set_type("Split");

  LayerParameter scale_param;
  scale_param.set_type("Scale");

  LayerParameter permute_param;
  permute_param.set_type("Permute");

  LayerParameter reshape_param;
  reshape_param.set_type("Reshape");

  LayerParameter bias_layer_param;
  bias_layer_param.set_type("Bias");

  LayerParameter pool_param;
  pool_param.set_type("Pooling");

  LayerParameter reshape_layer_param;
  reshape_layer_param.set_type("Reshape");

  BlobShape input_shape;
  input_shape.add_dim(1);  // c_0 and h_0 are a single timestep
  input_shape.add_dim(this->N_);
  input_shape.add_dim(num_output);

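  // Declare the recurrent state blobs as net inputs; the generated prototxt
  // contains, for example:
  //   input: "c_0"
  //   input_shape { dim: 1 dim: <N> dim: <num_output> }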
  net_param->add_input("c_0");
  net_param->add_input_shape()->CopyFrom(input_shape);
  net_param->add_input("h_0");
  net_param->add_input_shape()->CopyFrom(input_shape);

  // Slice the sequence-continuation indicators into per-timestep blobs
  // cont_1, ..., cont_T (overriding slice_param's default axis 0 with
  // axis 1).
  LayerParameter* cont_slice_param = net_param->add_layer();
  cont_slice_param->CopyFrom(slice_param);
  cont_slice_param->set_name("cont_slice");
  cont_slice_param->add_bottom("cont");
  cont_slice_param->mutable_slice_param()->set_axis(1);

  // Slice the input sequence x along the time axis (axis 0) into x_1, ...,
  // x_T; the per-timestep tops are added inside the unroll loop below.
  LayerParameter* x_slice_param = net_param->add_layer();
  x_slice_param->CopyFrom(slice_param);
  x_slice_param->set_name("x_slice");
  x_slice_param->add_bottom("x");

  // The stock LSTM transform of all timesteps of x to the hidden state
  // dimension (W_xc_x = W_xc * x + b_c) is disabled here; the attention
  // model instead transforms and re-weights x inside the per-timestep loop
  // below.
  /*
  {
    LayerParameter* x_transform_param = net_param->add_layer();
    x_transform_param->CopyFrom(biased_hidden_param);
    x_transform_param->set_name("x_transform");
    x_transform_param->add_param()->set_name("W_xc");
    x_transform_param->add_param()->set_name("b_c");
    x_transform_param->add_bottom("x");
    x_transform_param->add_top("W_xc_x");
  }

  if (this->static_input_) {
    // Add layer to transform x_static to the gate dimension.
    // W_xc_x_static = W_xc_static * x_static
    LayerParameter* x_static_transform_param = net_param->add_layer();
    x_static_transform_param->CopyFrom(hidden_param);
    x_static_transform_param->mutable_inner_product_param()->set_axis(1);
    x_static_transform_param->set_name("W_xc_x_static");
    x_static_transform_param->add_param()->set_name("W_xc_static");
    x_static_transform_param->add_bottom("x_static");
    x_static_transform_param->add_top("W_xc_x_static");

    LayerParameter* reshape_param = net_param->add_layer();
    reshape_param->set_type("Reshape");
    BlobShape* new_shape =
        reshape_param->mutable_reshape_param()->mutable_shape();
    new_shape->add_dim(1);  // One timestep.
    new_shape->add_dim(this->N_);
    new_shape->add_dim(
        x_static_transform_param->inner_product_param().num_output());
    reshape_param->add_bottom("W_xc_x_static");
    reshape_param->add_top("W_xc_x_static");
  }

  LayerParameter* x_slice_param = net_param->add_layer();
  x_slice_param->CopyFrom(slice_param);
  x_slice_param->add_bottom("W_xc_x");
  x_slice_param->set_name("W_xc_x_slice");
  */

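  // Concat layers that stitch the per-timestep results back together along
  // the time axis (axis 0): "h" collects the hidden states h_1, ..., h_T and
  // "mask" collects the attention weights of each timestep.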
  LayerParameter output_concat_layer;
  output_concat_layer.set_name("h_concat");
  output_concat_layer.set_type("Concat");
  output_concat_layer.add_top("h");
  output_concat_layer.mutable_concat_param()->set_axis(0);

  LayerParameter output_m_layer;
  output_m_layer.set_name("m_concat");
  output_m_layer.set_type("Concat");
  output_m_layer.add_top("mask");
  output_m_layer.mutable_concat_param()->set_axis(0);  // the second output

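  // Unroll the net through time: iteration t emits the layer group for
  // timestep t, wired to the previous step's blobs via the "_<t-1>" suffix
  // (tm1s) and to the current step's via "_<t>" (ts).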
  for (int t = 1; t <= this->T_; ++t) {
    string tm1s = format_int(t - 1);
    string ts = format_int(t);

    cont_slice_param->add_top("cont_" + ts);
    x_slice_param->add_top("x_" + ts);

    // Add a layer to permute x_t
    {
      LayerParameter* permute_x_param = net_param->add_layer();
      permute_x_param->CopyFrom(permute_param);
      permute_x_param->set_name("permute_x_" + ts);
      permute_x_param->mutable_permute_param()->add_order(2);
      permute_x_param->mutable_permute_param()->add_order(0);
      permute_x_param->mutable_permute_param()->add_order(1);
      permute_x_param->mutable_permute_param()->add_order(3);
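      // Permute semantics (SSD-style Permute layer): output axis i takes
      // input axis order[i], so order {2, 0, 1, 3} moves input axis 2 to
      // the front and keeps the final axis in place.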