// genericio.cpp -- pybind11 bindings for GenericIO (Python module `pygio`).
#include "GenericIO.h"
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <algorithm>
#include <array>
#include <cstdint>
#include <map>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#ifndef GENERICIO_NO_MPI
#include <mpi.h>
#endif

namespace py = pybind11;

class PyGenericIO : public gio::GenericIO {
public:
  PyGenericIO(const std::string& filename, gio::GenericIO::FileIO method=gio::GenericIO::FileIOPOSIX, gio::GenericIO::MismatchBehavior redistribute=gio::GenericIO::MismatchRedistribute)
#ifdef GENERICIO_NO_MPI
      : gio::GenericIO(filename, method), num_ranks(0) {
#else
      : gio::GenericIO(MPI_COMM_WORLD, filename, method), num_ranks(0) {
#endif
    // open headers and rank info
    openAndReadHeader(redistribute);
    num_ranks = readNRanks();
    // read variable info
    getVariableInfo(variables);
  }

  void inspect() {
    int rank;
  #ifdef GENERICIO_NO_MPI
    rank = 0;
  #else
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  #endif
    if(rank == 0) {
      std::stringstream s;
      s << "Number of Elements: " << readNumElems() << "\n";
      s << "Total number of Elements: " << readTotalNumElems() << "\n";
      s << "[data type] Variable name\n";
      s << "---------------------------------------------\n";
      for (int i = 0; i < variables.size(); ++i) {
        gio::GenericIO::VariableInfo vinfo = variables[i];
        if (vinfo.IsFloat)
          s << "[f";
        else
          s << "[i";
        int NumElements = vinfo.Size / vinfo.ElementSize;
        s << " " << vinfo.ElementSize * 8;
        if (NumElements > 1)
          s << "x" << NumElements;
        s << "] ";
        s << vinfo.Name << "\n";
      }
      s << "\n(i=integer,f=floating point, number bits size)\n";
      py::print(s.str());
    }
  }

  std::map<std::string, py::array> read(std::optional<std::vector<std::string>> var_names) {
    // read number of elements
    int64_t num_elem = readNumElems();
    
    // if no argument, read all
    if(!var_names.has_value()) {
      var_names.emplace(std::vector<std::string>());
      for(const auto& v: variables) {
        var_names->push_back(v.Name);
      }
    }

    clearVariables();
    std::map<std::string, py::array> result;

    for(const std::string& var_name: *var_names) {
      auto varp = std::find_if(
        variables.begin(), 
        variables.end(), 
        [&var_name](const auto& v){ return v.Name == var_name; }
        );
      if (varp != variables.end()) {
        // extra space
        py::ssize_t readsize = num_elem + requestedExtraSpace()/(*varp).ElementSize;
        if((*varp).IsFloat && (*varp).ElementSize == 4) {
          result[var_name] = py::array_t<float>(readsize);
          addVariable(*varp, result[var_name].mutable_data(), gio::GenericIO::VarHasExtraSpace);
        } else if((*varp).IsFloat && (*varp).ElementSize == 8) {
          result[var_name] = py::array_t<double>(readsize);
          addVariable(*varp, result[var_name].mutable_data(), gio::GenericIO::VarHasExtraSpace);
        } else if(!(*varp).IsFloat && (*varp).ElementSize == 4) {
          result[var_name] = py::array_t<int32_t>(readsize);
          addVariable(*varp, result[var_name].mutable_data(), gio::GenericIO::VarHasExtraSpace);
        } else if(!(*varp).IsFloat && (*varp).ElementSize == 8) {
          result[var_name] = py::array_t<int64_t>(readsize);
          addVariable(*varp, result[var_name].mutable_data(), gio::GenericIO::VarHasExtraSpace);
        } else if(!(*varp).IsFloat && (*varp).ElementSize == 2) {
          result[var_name] = py::array_t<uint16_t>(readsize);
          addVariable(*varp, result[var_name].mutable_data(), gio::GenericIO::VarHasExtraSpace);
        }
      }
    }
    
    readData();
  #ifndef GENERICIO_NO_MPI
    MPI_Barrier(MPI_COMM_WORLD);
  #endif

    // get rid of extraspace
    std::for_each(result.begin(), result.end(), [&](auto& item){ item.second.resize({num_elem}); });

    return result;
  }

  const std::vector<gio::GenericIO::VariableInfo> &get_variables() {
    return variables;
  }

  std::array<double, 3> read_phys_origin() {
    std::array<double, 3> origin;
    readPhysOrigin(origin.data());
    return origin;
  }

  std::array<double, 3> read_phys_scale() {
    std::array<double, 3> scale;
    readPhysScale(scale.data());
    return scale;
  }

private:
  int num_ranks;
  std::vector<gio::GenericIO::VariableInfo> variables;
};

std::map<std::string, py::array> read_genericio(std::string filename, std::optional<std::vector<std::string>> var_names, PyGenericIO::FileIO method=PyGenericIO::FileIO::FileIOPOSIX, PyGenericIO::MismatchBehavior redistribute=PyGenericIO::MismatchBehavior::MismatchRedistribute) {
  PyGenericIO reader(filename, method, redistribute);
  return reader.read(var_names);
}

void inspect_genericio(std::string filename, PyGenericIO::FileIO method=PyGenericIO::FileIO::FileIOPOSIX, PyGenericIO::MismatchBehavior redistribute=PyGenericIO::MismatchBehavior::MismatchRedistribute) {
  PyGenericIO reader(filename, method, redistribute);
  reader.inspect();
}

#ifndef GENERICIO_NO_MPI
void write_genericio(std::string filename, std::map<std::string, py::array> variables, std::array<double, 3> phys_scale, std::array<double, 3> phys_origin, PyGenericIO::FileIO method=PyGenericIO::FileIO::FileIOPOSIX) {
  // check data integrity, find particle count
  int64_t particle_count = -1;
  for(auto const& [name, data]: variables) {
    if(data.ndim() != 1) {
      throw std::runtime_error("dimension of array must be 1 (" + name + ")");
    }
    if(particle_count == -1) {
      particle_count = data.size();
    } else if(particle_count != data.size()) {
      throw std::runtime_error("arrays do not have same length (" + name + ")");
    }
  }

  gio::GenericIO writer(MPI_COMM_WORLD, filename, method);

  writer.setNumElems(particle_count);

  // set size
  for (int d = 0; d < 3; ++d) {
    writer.setPhysOrigin(phys_origin[d], d);
    writer.setPhysScale(phys_scale[d], d);
  }

  for(auto& [name, data]: variables) {
    if(py::isinstance<py::array_t<float>>(data)) 
      writer.addVariable(name.c_str(), reinterpret_cast<float*>(data.mutable_data()));
    else if(py::isinstance<py::array_t<double>>(data)) 
      writer.addVariable(name.c_str(), reinterpret_cast<double*>(data.mutable_data()));
    else if(py::isinstance<py::array_t<int32_t>>(data)) 
      writer.addVariable(name.c_str(), reinterpret_cast<int32_t*>(data.mutable_data()));
    else if(py::isinstance<py::array_t<int64_t>>(data)) 
      writer.addVariable(name.c_str(), reinterpret_cast<int64_t*>(data.mutable_data()));
    else if(py::isinstance<py::array_t<uint16_t>>(data)) 
      writer.addVariable(name.c_str(), reinterpret_cast<uint16_t*>(data.mutable_data()));
    else
      throw std::runtime_error("array dtype not supported for " + name);
  }
  writer.write();
  MPI_Barrier(MPI_COMM_WORLD);

}
#endif

PYBIND11_MODULE(pygio, m) {
  m.doc() = "genericio python module";
#ifndef GENERICIO_NO_MPI
  m.def("_init_mpi", [](){
    int initialized;
    MPI_Initialized(&initialized);
    if(!initialized) {
      int level_provided;
      MPI_Init_thread(nullptr, nullptr, MPI_THREAD_SINGLE, &level_provided); 
    }
#endif

  py::class_<PyGenericIO> pyGenericIO(m, "PyGenericIO");

  py::enum_<PyGenericIO::FileIO>(pyGenericIO, "FileIO")
    .value("FileIOMPI", PyGenericIO::FileIO::FileIOMPI)
    .value("FileIOPOSIX", PyGenericIO::FileIO::FileIOPOSIX)
    .value("FileIOMPICollective", PyGenericIO::FileIO::FileIOMPICollective);

  py::enum_<PyGenericIO::MismatchBehavior>(pyGenericIO, "MismatchBehavior")
    .value("MismatchAllowed", PyGenericIO::MismatchBehavior::MismatchAllowed)
    .value("MismatchDisallowed", PyGenericIO::MismatchBehavior::MismatchDisallowed)
    .value("MismatchRedistribute", PyGenericIO::MismatchBehavior::MismatchRedistribute);

  pyGenericIO.def(py::init<std::string, PyGenericIO::FileIO, PyGenericIO::MismatchBehavior>(), py::arg("filename"), py::arg("method")=PyGenericIO::FileIO::FileIOPOSIX, py::arg("redistribute")=PyGenericIO::MismatchBehavior::MismatchRedistribute)
      .def("inspect", &PyGenericIO::inspect, "Print variable infos and size of GenericIO file")
      .def("get_variables", &PyGenericIO::get_variables, "Get a list of VariableInformations defined in the GenericIO file")
      .def("read_num_elems", (size_t (PyGenericIO::*)(int))(&PyGenericIO::readNumElems), py::arg("eff_rank")=-1)
      .def("read_total_num_elems", (uint64_t (PyGenericIO::*)(void))(&PyGenericIO::readTotalNumElems))
      .def("read_phys_origin", &PyGenericIO::read_phys_origin)
      .def("read_phys_scale", &PyGenericIO::read_phys_scale)
      .def("read", &PyGenericIO::read, py::arg("variables")=nullptr);

  py::class_<gio::GenericIO::VariableInfo>(pyGenericIO, "VariableInfo")
      .def_readonly("name", &gio::GenericIO::VariableInfo::Name)
      .def_readonly("size", &gio::GenericIO::VariableInfo::Size)
      .def_readonly("element_size", &gio::GenericIO::VariableInfo::ElementSize)
      .def_readonly("is_float", &gio::GenericIO::VariableInfo::IsFloat)
      .def("__repr__", [](const gio::GenericIO::VariableInfo &vi) {
        return std::string("<PyGenericIO.VariableInfo type=") +
               (vi.IsFloat ? "float" : "int") + " name='" + vi.Name + "'>";
      });
  m.def("read_genericio", &read_genericio, py::arg("filename"), py::arg("variables")=nullptr, py::arg("method")=PyGenericIO::FileIO::FileIOPOSIX, py::arg("redistribute")=PyGenericIO::MismatchBehavior::MismatchRedistribute, py::return_value_policy::move);
  m.def("inspect_genericio", &inspect_genericio, py::arg("filename"), py::arg("method")=PyGenericIO::FileIO::FileIOPOSIX, py::arg("redistribute")=PyGenericIO::MismatchBehavior::MismatchRedistribute);
#ifndef GENERICIO_NO_MPI
  m.def("write_genericio", &write_genericio, py::arg("filename"), py::arg("variables"), py::arg("phys_scale"), py::arg("phys_origin") = std::array<double, 3>({0., 0., 0.}), py::arg("method")=PyGenericIO::FileIO::FileIOPOSIX);
#endif