//
// caffe_.cpp provides wrappers of the caffe::Solver class, caffe::Net class,
// caffe::Layer class and caffe::Blob class and some caffe::Caffe functions,
// so that one could easily use Caffe from matlab.
// Note that for matlab, we will simply use float as the data type.
// Internally, data is stored with dimensions reversed from Caffe's:
// e.g., if the Caffe blob axes are (num, channels, height, width),
// the matcaffe data is stored as (width, height, channels, num)
// where width is the fastest dimension.
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include "mex.h"
#include "caffe/caffe.hpp"
#define MEX_ARGS int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs
using namespace caffe; // NOLINT(build/namespaces)
// Abort the current MEX call, reporting `msg` as the MATLAB error text.
// Per the MATLAB MEX API, mexErrMsgTxt terminates the MEX function and does
// not return to the caller.
inline void mxERROR(const char* msg) { mexErrMsgTxt(msg); }

// MEX counterpart of glog's CHECK: raise a MATLAB error carrying `msg`
// unless `expr` holds.
inline void mxCHECK(bool expr, const char* msg) {
  if (expr) {
    return;
  }
  mxERROR(msg);
}
// Check if a file exists and can be opened
void mxCHECK_FILE_EXIST(const char* file) {
std::ifstream f(file);
if (!f.good()) {
f.close();
std::string msg("Could not open file ");
msg += file;
mxERROR(msg.c_str());
}
f.close();
}
// Owning references to the caffe::Solver and caffe::Net instances created
// through this MEX interface. MATLAB-side handles (see setup_handle) store
// only raw addresses, so these vectors are what keep the objects alive.
// NOTE(review): nets_ is not touched in the visible code; presumably it is
// filled by net-creation commands later in the file.
static vector<shared_ptr<Solver<float> > > solvers_;
static vector<shared_ptr<Net<float> > > nets_;
// Session key stamped into every handle (setup_handle) and compared in
// handle_to_ptr, so handles minted under a different key are rejected.
// Drawn from caffe_rng_rand() at static-initialization time; per the original
// note it is also regenerated every time reset is called (not visible here).
static double init_key = static_cast<double>(caffe_rng_rand());
/** -----------------------------------------------------------------
** data conversion functions
**/
// Selects which of a Blob's two buffers a copy reads from or writes to.
enum WhichMemory {
  DATA = 0,  // the blob's data buffer (cpu_data / gpu_data)
  DIFF = 1   // the blob's diff buffer (cpu_diff / gpu_diff)
};
// Copy matlab array to Blob data or diff
static void mx_mat_to_blob(const mxArray* mx_mat, Blob<float>* blob,
WhichMemory data_or_diff) {
mxCHECK(blob->count() == mxGetNumberOfElements(mx_mat),
"number of elements in target blob doesn't match that in input mxArray");
const float* mat_mem_ptr = reinterpret_cast<const float*>(mxGetData(mx_mat));
float* blob_mem_ptr = NULL;
switch (Caffe::mode()) {
case Caffe::CPU:
blob_mem_ptr = (data_or_diff == DATA ?
blob->mutable_cpu_data() : blob->mutable_cpu_diff());
break;
case Caffe::GPU:
blob_mem_ptr = (data_or_diff == DATA ?
blob->mutable_gpu_data() : blob->mutable_gpu_diff());
break;
default:
mxERROR("Unknown Caffe mode");
}
caffe_copy(blob->count(), mat_mem_ptr, blob_mem_ptr);
}
// Build a new MATLAB single-precision array holding a copy of the data or
// diff buffer of `blob` (chosen by `data_or_diff`). Caffe's axes are written
// out in reverse order, so a blob of shape (N, C, H, W) yields a
// W x H x C x N array; a blob with no axes yields a one-element array.
// The source buffer (CPU or GPU) follows the current Caffe::mode().
static mxArray* blob_to_mx_mat(const Blob<float>* blob,
    WhichMemory data_or_diff) {
  const int num_axes = blob->num_axes();
  vector<mwSize> dims;
  // Walk Caffe's axes from last to first so dims comes out reversed.
  for (int axis = num_axes - 1; axis >= 0; --axis) {
    dims.push_back(static_cast<mwSize>(blob->shape(axis)));
  }
  // mxCreateNumericArray needs at least one dimension: map a scalar blob to a
  // single-element array.
  if (dims.empty()) {
    dims.push_back(1);
  }
  mxArray* mx_mat =
      mxCreateNumericArray(dims.size(), dims.data(), mxSINGLE_CLASS, mxREAL);
  float* destination = static_cast<float*>(mxGetData(mx_mat));
  const bool want_data = (data_or_diff == DATA);
  const float* source = NULL;
  switch (Caffe::mode()) {
  case Caffe::CPU:
    if (want_data) {
      source = blob->cpu_data();
    } else {
      source = blob->cpu_diff();
    }
    break;
  case Caffe::GPU:
    if (want_data) {
      source = blob->gpu_data();
    } else {
      source = blob->gpu_diff();
    }
    break;
  default:
    mxERROR("Unknown Caffe mode");
  }
  caffe_copy(blob->count(), source, destination);
  return mx_mat;
}
// Convert vector<int> to a MATLAB double column vector of size n x 1, where
// n = int_vec.size(). Note: the result is a column vector, not a row vector;
// transpose on the MATLAB side if a row is needed.
static mxArray* int_vec_to_mx_vec(const vector<int>& int_vec) {
  mxArray* mx_vec = mxCreateDoubleMatrix(int_vec.size(), 1, mxREAL);
  double* vec_mem_ptr = mxGetPr(mx_vec);
  // size_t index avoids a signed/unsigned comparison against size().
  for (size_t i = 0; i < int_vec.size(); ++i) {
    vec_mem_ptr[i] = static_cast<double>(int_vec[i]);
  }
  return mx_vec;
}
// Build an n x 1 MATLAB cell array whose i-th cell holds a char-array copy
// of str_vec[i] (n = str_vec.size()).
static mxArray* str_vec_to_mx_strcell(const vector<std::string>& str_vec) {
  mxArray* cells = mxCreateCellMatrix(str_vec.size(), 1);
  mwIndex cell_index = 0;
  for (vector<std::string>::const_iterator it = str_vec.begin();
       it != str_vec.end(); ++it, ++cell_index) {
    mxSetCell(cells, cell_index, mxCreateString(it->c_str()));
  }
  return cells;
}
/** -----------------------------------------------------------------
** handle and pointer conversion functions
** a handle is a struct array with the following fields
** (uint64) ptr : the pointer to the C++ object
** (double) init_key : caffe initialization key
**/
// Convert a handle in matlab to a pointer in C++. Check if init_key matches
template <typename T>
static T* handle_to_ptr(const mxArray* mx_handle) {
mxArray* mx_ptr = mxGetField(mx_handle, 0, "ptr");
mxArray* mx_init_key = mxGetField(mx_handle, 0, "init_key");
mxCHECK(mxIsUint64(mx_ptr), "pointer type must be uint64");
mxCHECK(mxGetScalar(mx_init_key) == init_key,
"Could not convert handle to pointer due to invalid init_key. "
"The object might have been cleared.");
return reinterpret_cast<T*>(*reinterpret_cast<uint64_t*>(mxGetData(mx_ptr)));
}
// Allocate a ptr_num x 1 struct array with the handle fields `ptr` and
// `init_key`, leaving every field unset; callers fill entries in afterwards
// with setup_handle. T is unused and only mirrors the other handle helpers.
template <typename T>
static mxArray* create_handle_vec(int ptr_num) {
  const char* field_names[] = { "ptr", "init_key" };
  const int field_count =
      static_cast<int>(sizeof(field_names) / sizeof(field_names[0]));
  return mxCreateStructMatrix(ptr_num, 1, field_count, field_names);
}
// Set up a handle in a handle struct vector by its index
template <typename T>
static void setup_handle(const T* ptr, int index, mxArray* mx_handle_vec) {
mxArray* mx_ptr = mxCreateNumericMatrix(1, 1, mxUINT64_CLASS, mxREAL);
*reinterpret_cast<uint64_t*>(mxGetData(mx_ptr)) =
reinterpret_cast<uint64_t>(ptr);
mxSetField(mx_handle_vec, index, "ptr", mx_ptr);
mxSetField(mx_handle_vec, index, "init_key", mxCreateDoubleScalar(init_key));
}
// Wrap a single C++ pointer in a 1x1 MATLAB handle struct. The handle stores
// only the address; it does not take ownership of *ptr.
template <typename T>
static mxArray* ptr_to_handle(const T* ptr) {
  mxArray* const handle = create_handle_vec<T>(1);
  setup_handle<T>(ptr, 0, handle);
  return handle;
}
// Convert a vector of shared_ptr into an n x 1 handle struct vector, one
// handle per element and in the same order. The handles store raw addresses
// only; ownership stays with the shared_ptrs in ptr_vec.
template <typename T>
static mxArray* ptr_vec_to_handle_vec(const vector<shared_ptr<T> >& ptr_vec) {
  mxArray* handles = create_handle_vec<T>(ptr_vec.size());
  int position = 0;
  for (typename vector<shared_ptr<T> >::const_iterator it = ptr_vec.begin();
       it != ptr_vec.end(); ++it) {
    setup_handle(it->get(), position, handles);
    ++position;
  }
  return handles;
}
/** -----------------------------------------------------------------
** matlab command functions: caffe_(api_command, arg1, arg2, ...)
**/
// Usage: caffe_('get_solver', solver_file);
// Construct a caffe::SGDSolver from the solver definition file at
// solver_file, keep it alive in solvers_, and return a handle to it.
static void get_solver(MEX_ARGS) {
  const bool valid_args = (nrhs == 1) && mxIsChar(prhs[0]);
  mxCHECK(valid_args, "Usage: caffe_('get_solver', solver_file)");
  char* const file_name = mxArrayToString(prhs[0]);
  mxCHECK_FILE_EXIST(file_name);
  solvers_.push_back(shared_ptr<Solver<float> >(
      new caffe::SGDSolver<float>(file_name)));
  plhs[0] = ptr_to_handle<Solver<float> >(solvers_.back().get());
  // mxArrayToString allocates with mxMalloc; release it once done.
  mxFree(file_name);
}
// Usage: caffe_('solver_get_attr', hSolver)
// Return a 1x1 struct with two handle fields: hNet_net wraps solver->net(),
// and hNet_test_nets is a handle vector over solver->test_nets().
static void solver_get_attr(MEX_ARGS) {
  const bool valid_args = (nrhs == 1) && mxIsStruct(prhs[0]);
  mxCHECK(valid_args, "Usage: caffe_('solver_get_attr', hSolver)");
  Solver<float>* const solver = handle_to_ptr<Solver<float> >(prhs[0]);
  const char* attr_names[] = { "hNet_net", "hNet_test_nets" };
  mxArray* attrs = mxCreateStructMatrix(1, 1, 2, attr_names);
  mxArray* net_handle = ptr_to_handle<Net<float> >(solver->net().get());
  mxArray* test_net_handles =
      ptr_vec_to_handle_vec<Net<float> >(solver->test_nets());
  mxSetField(attrs, 0, "hNet_net", net_handle);
  mxSetField(attrs, 0, "hNet_test_nets", test_net_handles);
  plhs[0] = attrs;
}
// Usage: caffe_('solver_ge
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
external.rar (45个子文件)
external
caffe
matlab
caffe_faster_rcnn
hdf5_hl.dll 137KB
libglog.dll 85KB
opencv_imgproc249.dll 2.09MB
gflags.lib 26KB
libprotobuf.lib 28.25MB
opencv_core249.dll 2.42MB
+caffe
set_device.m 261B
Solver.m 2KB
io.m 1KB
imagenet
ilsvrc_2012_mean.mat 593KB
private
CHECK.m 78B
is_valid_handle.m 962B
CHECK_FILE_EXIST.m 125B
caffe_.cpp 20KB
reset_all.m 180B
set_random_seed.m 290B
get_solver.m 308B
get_net.m 1KB
run_tests.m 382B
init_log.m 364B
set_mode_gpu.m 104B
+test
test_solver.m 1KB
test_net.m 4KB
set_mode_cpu.m 104B
Blob.m 3KB
Layer.m 1017B
Net.m 7KB
cudart64_75.dll 352KB
caffe_.mexw64 16.39MB
cudart32_75.dll 285KB
lmdb.lib 523KB
caffe_.lib 2KB
pthreadVC2_x64.lib 29KB
libglog.lib 47KB
pthreadVC2.dll 81KB
hdf5_hl.lib 25KB
opencv_highgui249.dll 2.3MB
leveldb.lib 12.19MB
libiomp5md.dll 1.02MB
hdf5.lib 463KB
caffe_.exp 694B
hdf5.dll 2.61MB
zlib.dll 108KB
gflags.dll 230KB
szip.dll 71KB
共 45 条
- 1
Swordddddd
- 粉丝: 24
- 资源: 1
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
- 1
- 2
- 3
前往页