14 #include "MEAL/ProjectGradient.h"
15 #include "MEAL/Composite.h"
16 #include "stringtok.h"
//! GroupRule: a class that publicly inherits from T and manages a list of T components
// NOTE(review): the template header, opening brace and access specifiers are not
// visible in this view; T is presumably the template parameter -- confirm.
29 class GroupRule :
public T
//! The result type of this rule, taken from the base class T
34 typedef typename T::Result Result;
55 T*
get_model (
unsigned i) {
return model.at(i); }
56 const T*
get_model (
unsigned i)
const {
return model.at(i); }
59 unsigned get_nmodel ()
const {
return model.size(); }
//! Parse a line of text
// NOTE(review): the out-of-class definition is only partly visible in this file;
// it appears to try Function::parse first when components exist and otherwise
// to treat the leading token as the name of a new component -- confirm.
71 void parse (
const std::string& text);
79 std::string class_name()
const
80 {
return "MEAL::GroupRule[" + this->get_name() +
"]::"; }
//! Compute the result and its gradient
// result: receives the computed value.
// grad: pointer to the vector that receives the gradient; the definition resizes it
// to get_nparam() entries.
// NOTE(review): presumably grad may be null when no gradient is wanted -- the guard
// is not visible in this view; confirm.
94 void calculate (Result& result, std::vector<Result>* grad);
//! Combine element into total (pure virtual; must be supplied by derived classes)
// calculate() calls this once per component to fold the component result into the
// running result, and again to update individual gradient entries.
100 virtual void operate (Result& total,
const Result& element) = 0;
//! Return the term derived from element that is applied to other gradient entries
// Pure virtual; must be supplied by derived classes.  calculate() passes
// partial(comp_result) to operate() for the gradient entries that lie outside the
// range of the component that produced comp_result.
103 virtual const Result
partial (
const Result& element)
const = 0;
//! The component models, each wrapped in a Project<T>
108 std::vector< Project<T> > model;
//! Per-parameter gradient workspace; calculate() resizes it to the sum of
//! get_nparam() over all components
114 std::vector<Result> gradient;
// NOTE(review): the enclosing signature and braces are not visible in this view;
// `meta` is presumably another GroupRule that is being copied from -- confirm.
// Add every component held by meta to this object, in index order.
133 unsigned nmodel = meta.model.size();
134 for (
unsigned imodel=0; imodel < nmodel; imodel++)
135 add_model (meta.model[imodel]);
// Tail of a const member-function signature; its leading lines are not visible.
// NOTE(review): presumably print_parameters (std::string& text, const std::string& sep),
// judging by the use of `text` and the recursive call below -- confirm.
144 const std::string& sep)
const
146 unsigned nmodel = model.size();
// For each component: append sep followed by the component name to text, then let
// the component append its own parameters using a longer separator built from sep.
147 for (
unsigned imodel=0; imodel < nmodel; imodel++) {
148 text += sep + model[imodel]->get_name ();
149 model[imodel]->print_parameters (text, sep +
" ");
// NOTE(review): fragment of parse; the signature, braces and some statements
// (e.g. any return inside the try block) are not visible in this view.
// When at least one component exists, first hand the line to Function::parse.
156 if (model.size())
try
158 Function::parse(line);
161 catch (
Error& error) {
// Work on a copy of the line and extract the key from it.
// NOTE(review): stringtok presumably returns the leading token delimited by
// WHITESPACE and removes it from temp -- confirm against its declaration.
165 std::string temp = line;
166 std::string key = stringtok (temp, WHITESPACE);
// NOTE(review): class_name() as defined in this class already ends with "::", so
// this message contains "::::parse"; the other messages in this file append the
// method name directly to class_name() -- confirm which form is intended.
168 if (this->get_verbose())
169 std::cerr << class_name() <<
"::parse key '" << key <<
"'" << std::endl;
// Ask the Function factory for an object matching the key.  This local variable
// named `model` hides the data member of the same name for the rest of the scope.
171 Function* model = Function::factory (key);
// Check whether the created object is also a T.
173 T* mtype =
dynamic_cast<T*
>(model);
// Error reporting that the created object is not of type T.
// NOTE(review): the condition guarding this throw is not visible in this view.
175 throw Error (InvalidParam, class_name()+
"parse",
176 model->
get_name() +
" is not of type " +
177 std::string(T::Name));
// NOTE(review): fragment of add_model; the signature and the statement that stores
// x in `model` are not visible in this view.
// Map the last element of the model vector into the composite.
188 composite.map (model.back());
// Policy hook invoked with this object and the component x.
// NOTE(review): its effect is defined elsewhere -- see FunctionPolicyTraits.
190 FunctionPolicyTraits<T>::composite_component(
this, x);
// Verbose diagnostic: report the number of components now held.
192 if (this->get_verbose())
193 std::cerr << class_name() +
"add_model size=" << model.size() << std::endl;
// NOTE(review): fragment of remove_model; the signature, braces and any return
// after the erase are not visible in this view.
// Search for the component whose pointer equals x; when found, unmap it from the
// composite and erase it from the model vector.
199 for (
unsigned i=0; i < model.size(); i++)
200 if (model[i].ptr() == x)
202 composite.unmap (model[i]);
203 model.erase( model.begin() + i );
// Error raised when x is not among the components.
// NOTE(review): presumably reached only when the loop finds no match -- confirm.
207 throw Error (InvalidState, class_name() +
"remove_model",
"model not found");
// NOTE(review): the enclosing signature is not visible in this view (presumably an
// initialization helper used before accumulation -- confirm).
// Reset the result and every gradient entry to the value returned by get_identity().
220 result = get_identity();
222 for (
unsigned jgrad=0; jgrad<gradient.size(); jgrad++)
223 gradient[jgrad] = get_identity();
// NOTE(review): fragment of calculate; the first signature line, braces, the
// declarations of comp_result / jgrad / retval and several guards (e.g. any
// `if (grad)` test and the try/catch around evaluation) are not visible here.
228 std::vector<Result>* grad)
230 unsigned nmodel = model.size();
232 if (this->get_verbose())
233 std::cerr << class_name() +
"calculate nmodel=" << nmodel << std::endl;
// Gradient of a single component, as returned through evaluate().
239 std::vector<Result> comp_gradient;
// Pointer handed to evaluate(); remains null unless it is pointed at comp_gradient below.
242 std::vector<Result>* comp_gradient_ptr = 0;
244 unsigned total_nparam = 0;
// Sum the parameter counts of all components; this is the size of the gradient workspace.
248 for (
unsigned imodel=0; imodel < nmodel; imodel++)
249 total_nparam += model[imodel]->get_nparam();
251 comp_gradient_ptr = &comp_gradient;
252 gradient.resize (total_nparam);
// Offset into `gradient` of the first entry belonging to the current component.
258 unsigned igradient = 0;
260 for (
unsigned imodel=0; imodel < nmodel; imodel++)
262 if (this->get_verbose())
263 std::cerr << class_name() +
"calculate evaluate "
264 << model[imodel]->get_name() << std::endl;
// Evaluate the component, passing the (possibly null) gradient pointer.
269 comp_result = model[imodel]->evaluate (comp_gradient_ptr);
271 if (this->get_verbose())
272 std::cerr << class_name() +
"calculate "
273 << model[imodel]->get_name()
274 <<
" result=" << comp_result << std::endl;
// Fold the component result into the running result.
276 operate( result, comp_result );
281 unsigned ngrad = comp_gradient_ptr->size();
// Entries that precede this component's range are combined with partial(comp_result).
283 for (jgrad=0; jgrad<igradient; jgrad++)
284 operate( gradient[jgrad], partial(comp_result) );
// Entries in this component's own range are combined with its gradient terms.
286 for (jgrad=0; jgrad<ngrad; jgrad++)
287 operate( gradient[igradient + jgrad], (*comp_gradient_ptr)[jgrad] );
// Entries that follow this component's range are combined with partial(comp_result).
289 for (jgrad=igradient+ngrad; jgrad<gradient.size(); jgrad++)
290 operate( gradient[jgrad], partial(comp_result) );
// Add this method's name to an error and rethrow it with the component name appended.
// NOTE(review): the catch clause that declares `error` is not visible in this view.
295 error += class_name() +
"calculate";
296 throw error <<
" model=" << model[imodel]->get_name();
// Consistency check: the component gradient must have one entry per component parameter.
// NOTE(review): the %d conversions receive unsigned and size_type arguments; if Error
// formats its arguments printf-style this is a type mismatch -- confirm.
301 if (model[imodel]->get_nparam() != comp_gradient.size())
302 throw Error (InvalidState, class_name() +
"calculate",
303 "model[%d]=%s.get_nparam=%d != gradient.size=%d",
304 imodel, model[imodel]->get_name().c_str(),
305 model[imodel]->get_nparam(), comp_gradient.size());
// Advance the offset past the entries of this component.
307 igradient += comp_gradient.size();
// The accumulated offset must equal the total parameter count.
// NOTE(review): presumably evaluated after the component loop has finished -- confirm.
319 if (igradient != total_nparam)
320 throw Error (InvalidState, (class_name() +
"calculate").c_str(),
321 "after calculation igrad=%d != total_nparam=%d",
322 igradient, total_nparam);
// Size the caller's gradient to this object's parameter count, then let
// ProjectGradient fill it from the components and the gradient workspace.
324 grad->resize (this->get_nparam());
327 ProjectGradient (model, gradient, *grad);
// Verbose diagnostics: print the result and each gradient entry together with
// the value of get_infit(i).
330 if (this->get_verbose())
332 std::cerr << class_name() +
"calculate result\n " << retval << std::endl;
335 std::cerr << class_name() +
"calculate gradient" << std::endl;
336 for (
unsigned i=0; i<grad->size(); i++)
337 std::cerr <<
" " << i <<
":" << this->get_infit(i)
338 <<
"=" << (*grad)[i] << std::endl;